From d8e51f9dd21cdffb5f8eb1f6312b761529dbcb9b Mon Sep 17 00:00:00 2001 From: Ken Date: Tue, 8 Jul 2014 18:31:41 -0700 Subject: [PATCH 001/347] initial commit for pySparkStreaming --- bin/spark-submit | 6 + core/pom.xml | 2 +- .../apache/spark/api/python/PythonRDD.scala | 2 +- .../apache/spark/deploy/PythonRunner.scala | 1 + .../src/main/python/streaming/wordcount.py | 22 ++ python/pyspark/java_gateway.py | 3 + python/pyspark/streaming/__init__.py | 1 + python/pyspark/streaming/context.py | 133 ++++++++ python/pyspark/streaming/dstream.py | 315 ++++++++++++++++++ python/pyspark/streaming/duration.py | 171 ++++++++++ python/pyspark/streaming/jtime.py | 116 +++++++ python/pyspark/streaming/pyprint.py | 28 ++ python/pyspark/streaming/utils.py | 18 + streaming/pom.xml | 14 +- .../streaming/api/java/JavaDStreamLike.scala | 8 + .../streaming/api/python/PythonDStream.scala | 152 +++++++++ .../spark/streaming/dstream/DStream.scala | 68 +++- 17 files changed, 1050 insertions(+), 10 deletions(-) create mode 100644 examples/src/main/python/streaming/wordcount.py create mode 100644 python/pyspark/streaming/__init__.py create mode 100644 python/pyspark/streaming/context.py create mode 100644 python/pyspark/streaming/dstream.py create mode 100644 python/pyspark/streaming/duration.py create mode 100644 python/pyspark/streaming/jtime.py create mode 100644 python/pyspark/streaming/pyprint.py create mode 100644 python/pyspark/streaming/utils.py create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala diff --git a/bin/spark-submit b/bin/spark-submit index 9e7cecedd0325..ac275b7696d5c 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -37,6 +37,12 @@ done DEPLOY_MODE=${DEPLOY_MODE:-"client"} +# Figure out which Python executable to use +if [[ -z "$PYSPARK_PYTHON" ]]; then + PYSPARK_PYTHON="python" +fi +export PYSPARK_PYTHON + if [ -n "$DRIVER_MEMORY" ] && [ $DEPLOY_MODE == "client" ]; then export SPARK_DRIVER_MEMORY=$DRIVER_MEMORY fi diff --git a/core/pom.xml b/core/pom.xml index 8c23842730e37..43633dcb63f54 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.0.0 ../pom.xml diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index f6570d335757a..e88a54d2086ea 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -252,7 +252,7 @@ private class PythonException(msg: String, cause: Exception) extends RuntimeExce * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python. * This is used by PySpark's shuffle operations. 
*/ -private class PairwiseRDD(prev: RDD[Array[Byte]]) extends +private[spark] class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte])](prev) { override def getPartitions = prev.partitions override def compute(split: Partition, context: TaskContext) = diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 0d6751f3fa6d2..89f3fd47724fe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -57,6 +57,7 @@ object PythonRunner { val builder = new ProcessBuilder(Seq(pythonExec, "-u", formattedPythonFile) ++ otherArgs) val env = builder.environment() env.put("PYTHONPATH", pythonPath) + env.put("PYSPARK_PYTHON", pythonExec) env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize val process = builder.start() diff --git a/examples/src/main/python/streaming/wordcount.py b/examples/src/main/python/streaming/wordcount.py new file mode 100644 index 0000000000000..f44cd696894ba --- /dev/null +++ b/examples/src/main/python/streaming/wordcount.py @@ -0,0 +1,22 @@ +import sys +from operator import add + +from pyspark.streaming.context import StreamingContext +from pyspark.streaming.duration import * + +if __name__ == "__main__": + if len(sys.argv) != 2: + print >> sys.stderr, "Usage: wordcount " + exit(-1) + ssc = StreamingContext(appName="PythonStreamingWordCount", duration=Seconds(1)) + + lines = ssc.textFileStream(sys.argv[1]) + fm_lines = lines.flatMap(lambda x: x.split(" ")) + filtered_lines = fm_lines.filter(lambda line: "Spark" in line) + mapped_lines = fm_lines.map(lambda x: (x, 1)) + + fm_lines.pyprint() + filtered_lines.pyprint() + mapped_lines.pyprint() + ssc.start() + ssc.awaitTermination() diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 0dbead4415b02..7038c6422be47 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -82,6 +82,9 @@ def run(self): java_import(gateway.jvm, "org.apache.spark.SparkConf") java_import(gateway.jvm, "org.apache.spark.api.java.*") java_import(gateway.jvm, "org.apache.spark.api.python.*") + java_import(gateway.jvm, "org.apache.spark.streaming.*") + java_import(gateway.jvm, "org.apache.spark.streaming.api.java.*") + java_import(gateway.jvm, "org.apache.spark.streaming.api.python.*") java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*") java_import(gateway.jvm, "org.apache.spark.sql.SQLContext") java_import(gateway.jvm, "org.apache.spark.sql.hive.HiveContext") diff --git a/python/pyspark/streaming/__init__.py b/python/pyspark/streaming/__init__.py new file mode 100644 index 0000000000000..719592912e80c --- /dev/null +++ b/python/pyspark/streaming/__init__.py @@ -0,0 +1 @@ +__author__ = 'ktakagiw' diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py new file mode 100644 index 0000000000000..c8ae9c4af85c9 --- /dev/null +++ b/python/pyspark/streaming/context.py @@ -0,0 +1,133 @@ +__author__ = 'ktakagiw' + + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import shutil +import sys +from threading import Lock +from tempfile import NamedTemporaryFile + +from pyspark import accumulators +from pyspark.accumulators import Accumulator +from pyspark.broadcast import Broadcast +from pyspark.conf import SparkConf +from pyspark.files import SparkFiles +from pyspark.java_gateway import launch_gateway +from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer +from pyspark.storagelevel import StorageLevel +from pyspark.rdd import RDD +from pyspark.context import SparkContext + +from py4j.java_collections import ListConverter + +from pyspark.streaming.dstream import DStream + +class StreamingContext(object): + """ + Main entry point for Spark functionality. A StreamingContext represents the + connection to a Spark cluster, and can be used to create L{RDD}s and + broadcast variables on that cluster. + """ + + def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, + environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None, + gateway=None, duration=None): + """ + Create a new StreamingContext. At least the master and app name and duration + should be set, either through the named parameters here or through C{conf}. + + @param master: Cluster URL to connect to + (e.g. mesos://host:port, spark://host:port, local[4]). + @param appName: A name for your job, to display on the cluster web UI. + @param sparkHome: Location where Spark is installed on cluster nodes. + @param pyFiles: Collection of .zip or .py files to send to the cluster + and add to PYTHONPATH. These can be paths on the local file + system or HDFS, HTTP, HTTPS, or FTP URLs. + @param environment: A dictionary of environment variables to set on + worker nodes. + @param batchSize: The number of Python objects represented as a single + Java object. Set 1 to disable batching or -1 to use an + unlimited batch size. + @param serializer: The serializer for RDDs. + @param conf: A L{SparkConf} object setting Spark properties. + @param gateway: Use an existing gateway and JVM, otherwise a new JVM + will be instatiated. 
+ @param duration: A L{Duration} Duration for SparkStreaming + + """ + # Create the Python Sparkcontext + self._sc = SparkContext(master=master, appName=appName, sparkHome=sparkHome, + pyFiles=pyFiles, environment=environment, batchSize=batchSize, + serializer=serializer, conf=conf, gateway=gateway) + self._jvm = self._sc._jvm + self._jssc = self._initialize_context(self._sc._jsc, duration._jduration) + + # Initialize StremaingContext in function to allow subclass specific initialization + def _initialize_context(self, jspark_context, jduration): + return self._jvm.JavaStreamingContext(jspark_context, jduration) + + def actorStream(self, props, name, storageLevel, supervisorStrategy): + raise NotImplementedError + + def addStreamingListener(self, streamingListener): + raise NotImplementedError + + def awaitTermination(self, timeout=None): + if timeout: + self._jssc.awaitTermination(timeout) + else: + self._jssc.awaitTermination() + + def checkpoint(self, directory): + raise NotImplementedError + + def fileStream(self, directory, filter=None, newFilesOnly=None): + raise NotImplementedError + + def networkStream(self, receiver): + raise NotImplementedError + + def queueStream(self, queue, oneAtATime=True, defaultRDD=None): + raise NotImplementedError + + def rawSocketStream(self, hostname, port, storagelevel): + raise NotImplementedError + + def remember(self, duration): + raise NotImplementedError + + def socketStream(hostname, port, converter,storageLevel): + raise NotImplementedError + + def start(self): + self._jssc.start() + + def stop(self, stopSparkContext=True): + raise NotImplementedError + + def textFileStream(self, directory): + return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer()) + + def transform(self, seq): + raise NotImplementedError + + def union(self, seq): + raise NotImplementedError + diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py new file mode 100644 index 0000000000000..b422b147d11e1 --- /dev/null +++ b/python/pyspark/streaming/dstream.py @@ -0,0 +1,315 @@ +from base64 import standard_b64encode as b64enc +import copy +from collections import defaultdict +from collections import namedtuple +from itertools import chain, ifilter, imap +import operator +import os +import sys +import shlex +import traceback +from subprocess import Popen, PIPE +from tempfile import NamedTemporaryFile +from threading import Thread +import warnings +import heapq +from random import Random + +from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ + BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long +from pyspark.join import python_join, python_left_outer_join, \ + python_right_outer_join, python_cogroup +from pyspark.statcounter import StatCounter +from pyspark.rddsampler import RDDSampler +from pyspark.storagelevel import StorageLevel +#from pyspark.resultiterable import ResultIterable +from pyspark.rdd import _JavaStackTrace + +from py4j.java_collections import ListConverter, MapConverter + +__all__ = ["DStream"] + +class DStream(object): + def __init__(self, jdstream, ssc, jrdd_deserializer): + self._jdstream = jdstream + self._ssc = ssc + self.ctx = ssc._sc + self._jrdd_deserializer = jrdd_deserializer + + def generatedRDDs(self): + """ + // RDDs generated, marked as private[streaming] so that testsuites can access it + @transient + """ + pass + + def print_(self): + """ + """ + # print is a resrved name of Python. 
We cannot give print to function name + getattr(self._jdstream, "print")() + + def pyprint(self): + """ + """ + self._jdstream.pyprint() + + def cache(self): + """ + """ + raise NotImplementedError + + def checkpoint(self): + """ + """ + raise NotImplementedError + + def compute(self, time): + """ + """ + raise NotImplementedError + + def context(self): + """ + """ + raise NotImplementedError + + def count(self): + """ + """ + raise NotImplementedError + + def countByValue(self, numPartitions=None): + """ + """ + raise NotImplementedError + + def countByValueAndWindow(self, duration, slideDuration=None): + """ + """ + raise NotImplementedError + + def countByWindow(self, duration, slideDuration=None): + """ + """ + raise NotImplementedError + + def dstream(self): + """ + """ + raise NotImplementedError + + def filter(self, f): + """ + """ + def func(iterator): return ifilter(f, iterator) + return self.mapPartitions(func) + + def flatMap(self, f, preservesPartitioning=False): + """ + """ + def func(s, iterator): return chain.from_iterable(imap(f, iterator)) + return self.mapPartitionsWithIndex(func, preservesPartitioning) + + def foreachRDD(self, f, time): + """ + """ + raise NotImplementedError + + def glom(self): + """ + """ + raise NotImplementedError + + def map(self, f, preservesPartitioning=False): + """ + """ + def func(split, iterator): return imap(f, iterator) + return PipelinedDStream(self, func, preservesPartitioning) + + def mapPartitions(self, f): + """ + """ + def func(s, iterator): return f(iterator) + return self.mapPartitionsWithIndex(func) + + def perist(self, storageLevel): + """ + """ + raise NotImplementedError + + def reduce(self, func, numPartitions=None): + """ + + """ + return self._combineByKey(lambda x:x, func, func, numPartitions) + + def _combineByKey(self, createCombiner, mergeValue, mergeCombiners, + numPartitions = None): + """ + """ + if numPartitions is None: + numPartitions = self.ctx._defaultParallelism() + def combineLocally(iterator): + combiners = {} + for x in iterator: + (k, v) = x + if k not in combiners: + combiners[k] = createCombiner(v) + else: + combiners[k] = mergeValue(combiners[k], v) + return combiners.iteritems() + locally_combined = self.mapPartitions(combineLocally) + shuffled = locally_combined.partitionBy(numPartitions) + def _mergeCombiners(iterator): + combiners = {} + for (k, v) in iterator: + if not k in combiners: + combiners[k] = v + else: + combiners[k] = mergeCombiners(combiners[k], v) + return combiners.iteritems() + return shuffled.mapPartitions(_mergeCombiners) + + + def partitionBy(self, numPartitions, partitionFunc=None): + """ + Return a copy of the DStream partitioned using the specified partitioner. + + """ + if numPartitions is None: + numPartitions = self.ctx._defaultReducePartitions() + + if partitionFunc is None: + partitionFunc = lambda x: 0 if x is None else hash(x) + # Transferring O(n) objects to Java is too expensive. Instead, we'll + # form the hash buckets in Python, transferring O(numPartitions) objects + # to Java. Each object is a (splitNumber, [objects]) pair. 
+ outputSerializer = self.ctx._unbatched_serializer + def add_shuffle_key(split, iterator): + + buckets = defaultdict(list) + + for (k, v) in iterator: + buckets[partitionFunc(k) % numPartitions].append((k, v)) + for (split, items) in buckets.iteritems(): + yield pack_long(split) + yield outputSerializer.dumps(items) + keyed = PipelinedDStream(self, add_shuffle_key) + keyed._bypass_serializer = True + with _JavaStackTrace(self.ctx) as st: + #JavaDStream + #pairRDD = self.ctx._jvm.PairwiseDStream(keyed._jdstream.dstream()).asJavaPairRDD() + pairDStream = self.ctx._jvm.PairwiseDStream(keyed._jdstream.dstream()).asJavaPairDStream() + partitioner = self.ctx._jvm.PythonPartitioner(numPartitions, + id(partitionFunc)) + jdstream = pairDStream.partitionBy(partitioner).values() + dstream = DStream(jdstream, self._ssc, BatchedSerializer(outputSerializer)) + # This is required so that id(partitionFunc) remains unique, even if + # partitionFunc is a lambda: + dstream._partitionFunc = partitionFunc + return dstream + + + + def reduceByWindow(self, reduceFunc, windowDuration, slideDuration, inReduceTunc): + """ + """ + + raise NotImplementedError + + def repartition(self, numPartitions): + """ + """ + raise NotImplementedError + + def slice(self, fromTime, toTime): + """ + """ + raise NotImplementedError + + def transform(self, transformFunc): + """ + """ + raise NotImplementedError + + def transformWith(self, other, transformFunc): + """ + """ + raise NotImplementedError + + def union(self, that): + """ + """ + raise NotImplementedError + + def window(self, windowDuration, slideDuration=None): + """ + """ + raise NotImplementedError + + def wrapRDD(self, rdd): + """ + """ + raise NotImplementedError + + def mapPartitionsWithIndex(self, f, preservesPartitioning=False): + return PipelinedDStream(self, f, preservesPartitioning) + + +class PipelinedDStream(DStream): + def __init__(self, prev, func, preservesPartitioning=False): + if not isinstance(prev, PipelinedDStream) or not prev._is_pipelinable(): + # This transformation is the first in its stage: + self.func = func + self.preservesPartitioning = preservesPartitioning + self._prev_jdstream = prev._jdstream + self._prev_jrdd_deserializer = prev._jrdd_deserializer + else: + prev_func = prev.func + def pipeline_func(split, iterator): + return func(split, prev_func(split, iterator)) + self.func = pipeline_func + self.preservesPartitioning = \ + prev.preservesPartitioning and preservesPartitioning + self._prev_jdstream = prev._prev_jdstream # maintain the pipeline + self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer + self.is_cached = False + self.is_checkpointed = False + self._ssc = prev._ssc + self.ctx = prev.ctx + self.prev = prev + self._jdstream_val = None + self._jrdd_deserializer = self.ctx.serializer + self._bypass_serializer = False + + @property + def _jdstream(self): + if self._jdstream_val: + return self._jdstream_val + if self._bypass_serializer: + serializer = NoOpSerializer() + else: + serializer = self.ctx.serializer + + command = (self.func, self._prev_jrdd_deserializer, serializer) + pickled_command = CloudPickleSerializer().dumps(command) + broadcast_vars = ListConverter().convert( + [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], + self.ctx._gateway._gateway_client) + self.ctx._pickled_broadcast_vars.clear() + class_tag = self._prev_jdstream.classTag() + env = MapConverter().convert(self.ctx.environment, + self.ctx._gateway._gateway_client) + includes = ListConverter().convert(self.ctx._python_includes, + 
self.ctx._gateway._gateway_client) + python_dstream = self.ctx._jvm.PythonDStream(self._prev_jdstream.dstream(), + bytearray(pickled_command), + env, includes, self.preservesPartitioning, + self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator, + class_tag) + self._jdstream_val = python_dstream.asJavaDStream() + return self._jdstream_val + + def _is_pipelinable(self): + return not (self.is_cached or self.is_checkpointed) diff --git a/python/pyspark/streaming/duration.py b/python/pyspark/streaming/duration.py new file mode 100644 index 0000000000000..ef1b4f6cef237 --- /dev/null +++ b/python/pyspark/streaming/duration.py @@ -0,0 +1,171 @@ +__author__ = 'ktakagiw' + +from pyspark.streaming import utils + +class Duration(object): + """ + Duration for Spark Streaming application. Used to set duration + + Most of the time, you would create a Duration object with + C{Duration()}, which will load values from C{spark.streaming.*} Java system + properties as well. In this case, any parameters you set directly on + the C{Duration} object take priority over system properties. + + """ + def __init__(self, millis, _jvm=None): + """ + Create new Duration. + + @param millis: milisecond + + """ + self._millis = millis + + from pyspark.context import SparkContext + SparkContext._ensure_initialized() + _jvm = _jvm or SparkContext._jvm + self._jduration = _jvm.Duration(millis) + + def toString(self): + """ Return duration as string """ + return str(self._millis) + " ms" + + def isZero(self): + """ Check if millis is zero """ + return self._millis == 0 + + def prettyPrint(self): + """ + Return a human-readable string representing a duration + """ + return utils.msDurationToString(self._millis) + + def milliseconds(self): + """ Return millisecond """ + return self._millis + + def toFormattedString(self): + """ Return millisecond """ + return str(self._millis) + + def max(self, other): + """ Return higher Duration """ + Duration._is_duration(other) + if self > other: + return self + else: + return other + + def min(self, other): + """ Return lower Durattion """ + Duration._is_duration(other) + if self < other: + return self + else: + return other + + def __str__(self): + return self.toString() + + def __add__(self, other): + """ Add Duration and Duration """ + Duration._is_duration(other) + return Duration(self._millis + other._millis) + + def __sub__(self, other): + """ Subtract Duration by Duration """ + Duration._is_duration(other) + return Duration(self._millis - other._millis) + + def __mul__(self, other): + """ Multiple Duration by Duration """ + Duration._is_duration(other) + return Duration(self._millis * other._millis) + + def __div__(self, other): + """ + Divide Duration by Duration + for Python 2.X + """ + Duration._is_duration(other) + return Duration(self._millis / other._millis) + + def __truediv__(self, other): + """ + Divide Duration by Duration + for Python 3.0 + """ + Duration._is_duration(other) + return Duration(self._millis / other._millis) + + def __floordiv__(self, other): + """ Divide Duration by Duration """ + Duration._is_duration(other) + return Duration(self._millis // other._millis) + + def __len__(self): + """ Length of miilisecond in Duration """ + return len(self._millis) + + def __lt__(self, other): + """ Duration < Duration """ + Duration._is_duration(other) + return self._millis < other._millis + + def __le__(self, other): + """ Duration <= Duration """ + Duration._is_duration(other) + return self.millis <= other._millis + + def __eq__(self, other): + """ Duration == 
Duration """ + Duration._is_duration(other) + return self._millis == other._millis + + def __ne__(self, other): + """ Duration != Duration """ + Duration._is_duration(other) + return self._millis != other._millis + + def __gt__(self, other): + """ Duration > Duration """ + Duration._is_duration(other) + return self._millis > other._millis + + def __ge__(self, other): + """ Duration >= Duration """ + Duration._is_duration(other) + return self._millis >= other._millis + + @classmethod + def _is_duration(self, instance): + """ is instance Duration """ + if not isinstance(instance, Duration): + raise TypeError("This should be Duration") + +def Milliseconds(milliseconds): + """ + Helper function that creates instance of [[pysparkstreaming.duration]] representing + a given number of milliseconds. + """ + return Duration(milliseconds) + +def Seconds(seconds): + """ + Helper function that creates instance of [[pysparkstreaming.duration]] representing + a given number of seconds. + """ + return Duration(seconds * 1000) + +def Minites(minites): + """ + Helper function that creates instance of [[pysparkstreaming.duration]] representing + a given number of minutes. + """ + return Duration(minutes * 60000) + +if __name__ == "__main__": + d = Duration(1) + print d + print d.milliseconds() + diff --git a/python/pyspark/streaming/jtime.py b/python/pyspark/streaming/jtime.py new file mode 100644 index 0000000000000..41670af659ea3 --- /dev/null +++ b/python/pyspark/streaming/jtime.py @@ -0,0 +1,116 @@ +__author__ = 'ktakagiw' + +from pyspark.streaming import utils +from pyspark.streaming.duration import Duration + +class Time(object): + """ + Time for Spark Streaming application. Used to set Time + + Most of the time, you would create a Duration object with + C{Time()}, which will load values from C{spark.streaming.*} Java system + properties as well. In this case, any parameters you set directly on + the C{Time} object take priority over system properties. + + """ + def __init__(self, millis, _jvm=None): + """ + Create new Time. 
+ + @param millis: milisecond + + @param _jvm: internal parameter used to pass a handle to the + Java VM; does not need to be set by users + + """ + self._millis = millis + + from pyspark.context import StreamingContext + StreamingContext._ensure_initialized() + _jvm = _jvm or StreamingContext._jvm + self._jtime = _jvm.Time(millis) + + def toString(self): + """ Return time as string """ + return str(self._millis) + " ms" + + def milliseconds(self): + """ Return millisecond """ + return self._millis + + def max(self, other): + """ Return higher Time """ + Time._is_time(other) + if self > other: + return self + else: + return other + + def min(self, other): + """ Return lower Time """ + Time._is_time(other) + if self < other: + return self + else: + return other + + def __add__(self, other): + """ Add Time and Time """ + Duration._is_duration(other) + return Time(self._millis + other._millis) + + def __sub__(self, other): + """ Subtract Time by Duration or Time """ + if isinstance(other, Duration): + return Time(self._millis - other._millis) + elif isinstance(other, Time): + return Duration(self._mills, other._millis) + else: + raise TypeError + + def __lt__(self, other): + """ Time < Time """ + Time._is_time(other) + return self._millis < other._millis + + def __le__(self, other): + """ Time <= Time """ + Time._is_time(other) + return self.millis <= other._millis + + def __eq__(self, other): + """ Time == Time """ + Time._is_time(other) + return self._millis == other._millis + + def __ne__(self, other): + """ Time != Time """ + Time._is_time(other) + return self._millis != other._millis + + def __gt__(self, other): + """ Time > Time """ + Time._is_time(other) + return self._millis > other._millis + + def __ge__(self, other): + """ Time >= Time """ + Time._is_time(other) + return self._millis >= other._millis + + def isMultipbleOf(duration): + """ is multiple by Duration """ + Duration._is_duration(duration) + return self._millis % duration._millis == 0 + + def until(time, interval): + raise NotImplementedError + + def to(time, interval): + raise NotImplementedError + + @classmethod + def _is_time(self, instance): + """ is instance Time """ + if not isinstance(instance, Time): + raise TypeError diff --git a/python/pyspark/streaming/pyprint.py b/python/pyspark/streaming/pyprint.py new file mode 100644 index 0000000000000..fcdaca510812c --- /dev/null +++ b/python/pyspark/streaming/pyprint.py @@ -0,0 +1,28 @@ +import sys +from itertools import chain +from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer + +def collect(binary_file_path): + dse = PickleSerializer() + with open(binary_file_path, 'rb') as tempFile: + for item in dse.load_stream(tempFile): + yield item +def main(): + try: + binary_file_path = sys.argv[1] + except: + print "Missed FilePath in argement" + + if not binary_file_path: + return + + counter = 0 + for rdd in chain.from_iterable(collect(binary_file_path)): + print rdd + counter = counter + 1 + if counter >= 10: + print "..." 
+ break + +if __name__ =="__main__": + exit(main()) diff --git a/python/pyspark/streaming/utils.py b/python/pyspark/streaming/utils.py new file mode 100644 index 0000000000000..71aa3376c6578 --- /dev/null +++ b/python/pyspark/streaming/utils.py @@ -0,0 +1,18 @@ +__author__ = 'ktakagiw' + +def msDurationToString(ms): + """ + Returns a human-readable string representing a duration such as "35ms" + """ + second = 1000 + minute = 60 * second + hour = 60 * minute + + if ms < second: + return "%d ms" % ms + elif ms < minute: + return "%.1f s" % (float(ms) / second) + elif ms < hout: + return "%.1f m" % (float(ms) / minute) + else: + return "%.2f h" % (float(ms) / hour) diff --git a/streaming/pom.xml b/streaming/pom.xml index f506d6ce34a6f..88df63592efee 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.0.0 ../pom.xml @@ -69,14 +69,14 @@ org.scalatest scalatest-maven-plugin - - diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index a6184de4e83c1..cfa336df8674f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -54,6 +54,14 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T dstream.print() } + /** + * Print the first ten elements of each PythonRDD generated in the PythonDStream. This is an output + * operator, so this PythonDStream will be registered as an output stream and there materialized. + * This function is for PythonAPI. + */ + + def pyprint() = dstream.pyprint() + /** * Return a new DStream in which each RDD has a single element generated by counting each RDD * of this DStream. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala new file mode 100644 index 0000000000000..2d8b1e468dc4c --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.api.python + +import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} + +import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark._ +import org.apache.spark.util.Utils +import java.io._ +import scala.Some +import org.apache.spark.streaming.Duration +import scala.util.control.Breaks._ +import org.apache.spark.broadcast.Broadcast +import scala.Some +import org.apache.spark.streaming.Duration +import org.apache.spark.rdd.RDD +import org.apache.spark.api.python.PythonRDD + + +import org.apache.spark.streaming.{Duration, Time} +import org.apache.spark.streaming.dstream._ +import org.apache.spark.streaming.api.java._ +import org.apache.spark.rdd.RDD +import org.apache.spark.api.python._ +import org.apache.spark.api.python.PairwiseRDD + + +import scala.reflect.ClassTag + + +class PythonDStream[T: ClassTag]( + parent: DStream[T], + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + preservePartitoning: Boolean, + pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], + accumulator: Accumulator[JList[Array[Byte]]] + ) extends DStream[Array[Byte]](parent.ssc) { + + override def dependencies = List(parent) + + override def slideDuration: Duration = parent.slideDuration + + //pythonDStream compute + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { + parent.getOrCompute(validTime) match{ + case Some(rdd) => + val pythonRDD = new PythonRDD(rdd, command, envVars, pythonIncludes, preservePartitoning, pythonExec, broadcastVars, accumulator) + Some(pythonRDD.asJavaRDD.rdd) + case None => None + } + } + val asJavaDStream = JavaDStream.fromDStream(this) + + /** + * Print the first ten elements of each PythonRDD generated in this PythonDStream. This is an output + * operator, so this PythonDStream will be registered as an output stream and there materialized. + * Since serialized Python object is readable by Python, pyprint writes out binary data to + * temporary file and run python script to deserialized and print the first ten elements + */ + private[streaming] def ppyprint() { + def foreachFunc = (rdd: RDD[Array[Byte]], time: Time) => { + val iter = rdd.take(11).iterator + + // make a temporary file + val prefix = "spark" + val suffix = ".tmp" + val tempFile = File.createTempFile(prefix, suffix) + val tempFileStream = new DataOutputStream(new FileOutputStream(tempFile.getAbsolutePath)) + //write out serialized python object + PythonRDD.writeIteratorToStream(iter, tempFileStream) + tempFileStream.close() + + // This value has to be passed from python + val pythonExec = new ProcessBuilder().environment().get("PYSPARK_PYTHON") + val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME") + //val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/streaming/pyprint.py", tempFile.getAbsolutePath())) // why this fails to compile??? 
+ //absolute path to the python script is needed to change because we do not use pysparkstreaming + val pb = new ProcessBuilder(pythonExec, sparkHome + "/python/pysparkstreaming/streaming/pyprint.py", tempFile.getAbsolutePath) + val workerEnv = pb.environment() + + //envVars also need to be pass + //workerEnv.putAll(envVars) + val pythonPath = sparkHome + "/python/" + File.pathSeparator + workerEnv.get("PYTHONPATH") + workerEnv.put("PYTHONPATH", pythonPath) + val worker = pb.start() + val is = worker.getInputStream() + val isr = new InputStreamReader(is) + val br = new BufferedReader(isr) + + println ("-------------------------------------------") + println ("Time: " + time) + println ("-------------------------------------------") + + //print value from python std out + var line = "" + breakable { + while (true) { + line = br.readLine() + if (line == null) break() + println(line) + } + } + //delete temporary file + tempFile.delete() + println() + + } + new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register() + } +} + + +private class PairwiseDStream(prev:DStream[Array[Byte]]) extends +DStream[(Long, Array[Byte])](prev.ssc){ + override def dependencies = List(prev) + + override def slideDuration: Duration = prev.slideDuration + + override def compute(validTime:Time):Option[RDD[(Long, Array[Byte])]]={ + prev.getOrCompute(validTime) match{ + case Some(rdd)=>Some(rdd) + val pairwiseRDD = new PairwiseRDD(rdd) + Some(pairwiseRDD.asJavaPairRDD.rdd) + case None => None + } + } + val asJavaPairDStream : JavaPairDStream[Long, Array[Byte]] = JavaPairDStream(this) +} + + + + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 4709a62381647..ffd7f88fd9dd1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -18,11 +18,13 @@ package org.apache.spark.streaming.dstream -import java.io.{IOException, ObjectInputStream, ObjectOutputStream} +import java.io._ import scala.deprecated import scala.collection.mutable.HashMap import scala.reflect.ClassTag +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} +import scala.util.control.Breaks._ import org.apache.spark.{Logging, SparkException} import org.apache.spark.rdd.{BlockRDD, RDD} @@ -31,6 +33,8 @@ import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.scheduler.Job import org.apache.spark.util.MetadataCleaner +import org.apache.spark.streaming.Duration +import org.apache.spark.api.python.PythonRDD /** * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous @@ -601,6 +605,68 @@ abstract class DStream[T: ClassTag] ( new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register() } + + + + + /** + * Print the first ten elements of each PythonRDD generated in this PythonDStream. This is an output + * operator, so this PythonDStream will be registered as an output stream and there materialized. 
+ * Since serialized Python object is readable by Python, pyprint writes out binary data to + * temporary file and run python script to deserialized and print the first ten elements + */ + private[streaming] def pyprint() { + def foreachFunc = (rdd: RDD[T], time: Time) => { + val iter = rdd.take(11).iterator + + // make a temporary file + val prefix = "spark" + val suffix = ".tmp" + val tempFile = File.createTempFile(prefix, suffix) + val tempFileStream = new DataOutputStream(new FileOutputStream(tempFile.getAbsolutePath)) + //write out serialized python object + PythonRDD.writeIteratorToStream(iter, tempFileStream) + tempFileStream.close() + + // This value has to be passed from python + val pythonExec = new ProcessBuilder().environment().get("PYSPARK_PYTHON") + val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME") + //val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/streaming/pyprint.py", tempFile.getAbsolutePath())) // why this fails to compile??? + //absolute path to the python script is needed to change because we do not use pysparkstreaming + val pb = new ProcessBuilder(pythonExec, sparkHome + "/python/pyspark/streaming/pyprint.py", tempFile.getAbsolutePath) + val workerEnv = pb.environment() + + //envVars also need to be pass + //workerEnv.putAll(envVars) + val pythonPath = sparkHome + "/python/" + File.pathSeparator + workerEnv.get("PYTHONPATH") + workerEnv.put("PYTHONPATH", pythonPath) + val worker = pb.start() + val is = worker.getInputStream() + val isr = new InputStreamReader(is) + val br = new BufferedReader(isr) + + println ("-------------------------------------------") + println ("Time: " + time) + println ("-------------------------------------------") + + //print value from python std out + var line = "" + breakable { + while (true) { + line = br.readLine() + if (line == null) break() + println(line) + } + } + //delete temporary file + tempFile.delete() + println() + + } + new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register() + } + + /** * Return a new DStream in which each RDD contains all the elements in seen in a * sliding window of time over this DStream. 
The new DStream generates RDDs with From 1367be52f80ee55a1b0cb1070b8fb02cf258c0be Mon Sep 17 00:00:00 2001 From: Ken Takagiwa Date: Tue, 15 Jul 2014 15:41:52 -0700 Subject: [PATCH 002/347] comment PythonDStream.PairwiseDStream --- .../apache/spark/streaming/api/python/PythonDStream.scala | 3 ++- .../scala/org/apache/spark/streaming/dstream/DStream.scala | 6 ++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 2d8b1e468dc4c..fe67250604d8e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -129,7 +129,7 @@ class PythonDStream[T: ClassTag]( } } - +/* private class PairwiseDStream(prev:DStream[Array[Byte]]) extends DStream[(Long, Array[Byte])](prev.ssc){ override def dependencies = List(prev) @@ -146,6 +146,7 @@ DStream[(Long, Array[Byte])](prev.ssc){ } val asJavaPairDStream : JavaPairDStream[Long, Array[Byte]] = JavaPairDStream(this) } +*/ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index b24109074e816..d9d5446b62e9f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -620,10 +620,7 @@ abstract class DStream[T: ClassTag] ( new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register() } - - - - +//TODO move pyprint to PythonDStream /** * Print the first ten elements of each PythonRDD generated in this PythonDStream. This is an output * operator, so this PythonDStream will be registered as an output stream and there materialized. @@ -644,6 +641,7 @@ abstract class DStream[T: ClassTag] ( tempFileStream.close() // This value has to be passed from python + // Python currently does not do cluster deployment. But what happened val pythonExec = new ProcessBuilder().environment().get("PYSPARK_PYTHON") val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME") //val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/streaming/pyprint.py", tempFile.getAbsolutePath())) // why this fails to compile??? From 88068cf8439991b17c244d65af3192b49968583f Mon Sep 17 00:00:00 2001 From: Ken Takagiwa Date: Tue, 15 Jul 2014 17:19:20 -0700 Subject: [PATCH 003/347] modify dstream.py to fix indent error --- python/pyspark/streaming/dstream.py | 2 +- .../org/apache/spark/streaming/api/python/PythonDStream.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index b422b147d11e1..a512517f6e437 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -172,7 +172,7 @@ def _mergeCombiners(iterator): return shuffled.mapPartitions(_mergeCombiners) - def partitionBy(self, numPartitions, partitionFunc=None): + def partitionBy(self, numPartitions, partitionFunc=None): """ Return a copy of the DStream partitioned using the specified partitioner. 
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index fe67250604d8e..389136f9e21a0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -91,7 +91,7 @@ class PythonDStream[T: ClassTag]( tempFileStream.close() // This value has to be passed from python - val pythonExec = new ProcessBuilder().environment().get("PYSPARK_PYTHON") + //val pythonExec = new ProcessBuilder().environment().get("PYSPARK_PYTHON") val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME") //val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/streaming/pyprint.py", tempFile.getAbsolutePath())) // why this fails to compile??? //absolute path to the python script is needed to change because we do not use pysparkstreaming From 94a07879007d6e6157b7f5b59a04284996f5623f Mon Sep 17 00:00:00 2001 From: Ken Takagiwa Date: Tue, 15 Jul 2014 21:08:43 -0700 Subject: [PATCH 004/347] added reducedByKey not working yet --- .../src/main/python/streaming/wordcount.py | 10 ++++++- python/pyspark/streaming/dstream.py | 27 +++++++++++++++++-- .../streaming/api/python/PythonDStream.scala | 6 ++--- 3 files changed, 37 insertions(+), 6 deletions(-) diff --git a/examples/src/main/python/streaming/wordcount.py b/examples/src/main/python/streaming/wordcount.py index f44cd696894ba..3996991109d60 100644 --- a/examples/src/main/python/streaming/wordcount.py +++ b/examples/src/main/python/streaming/wordcount.py @@ -1,6 +1,7 @@ import sys from operator import add +from pyspark.conf import SparkConf from pyspark.streaming.context import StreamingContext from pyspark.streaming.duration import * @@ -8,15 +9,22 @@ if len(sys.argv) != 2: print >> sys.stderr, "Usage: wordcount " exit(-1) - ssc = StreamingContext(appName="PythonStreamingWordCount", duration=Seconds(1)) + conf = SparkConf() + conf.setAppName("PythonStreamingWordCount") + conf.set("spark.default.parallelism", 1) + +# ssc = StreamingContext(appName="PythonStreamingWordCount", duration=Seconds(1)) + ssc = StreamingContext(conf=conf, duration=Seconds(1)) lines = ssc.textFileStream(sys.argv[1]) fm_lines = lines.flatMap(lambda x: x.split(" ")) filtered_lines = fm_lines.filter(lambda line: "Spark" in line) mapped_lines = fm_lines.map(lambda x: (x, 1)) + reduced_lines = mapped_lines.reduce(add) fm_lines.pyprint() filtered_lines.pyprint() mapped_lines.pyprint() + reduced_lines.pyprint() ssc.start() ssc.awaitTermination() diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index a512517f6e437..e144f8bc1cc09 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -29,6 +29,7 @@ __all__ = ["DStream"] + class DStream(object): def __init__(self, jdstream, ssc, jrdd_deserializer): self._jdstream = jdstream @@ -149,7 +150,7 @@ def _combineByKey(self, createCombiner, mergeValue, mergeCombiners, """ """ if numPartitions is None: - numPartitions = self.ctx._defaultParallelism() + numPartitions = self._defaultReducePartitions() def combineLocally(iterator): combiners = {} for x in iterator: @@ -211,7 +212,6 @@ def add_shuffle_key(split, iterator): return dstream - def reduceByWindow(self, reduceFunc, windowDuration, slideDuration, inReduceTunc): """ """ @@ -254,8 +254,31 @@ def wrapRDD(self, rdd): raise NotImplementedError def 
mapPartitionsWithIndex(self, f, preservesPartitioning=False): + """ + + """ return PipelinedDStream(self, f, preservesPartitioning) + def _defaultReducePartitions(self): + """ + + """ + # hard code to avoid the error + return 2 + if self.ctx._conf.contains("spark.default.parallelism"): + return self.ctx.defaultParallelism + else: + return self.getNumPartitions() + + def getNumPartitions(self): + """ + Returns the number of partitions in RDD + >>> rdd = sc.parallelize([1, 2, 3, 4], 2) + >>> rdd.getNumPartitions() + 2 + """ + return self._jdstream.partitions().size() + class PipelinedDStream(DStream): def __init__(self, prev, func, preservesPartitioning=False): diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 389136f9e21a0..719dd0a6a53c2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -129,7 +129,7 @@ class PythonDStream[T: ClassTag]( } } -/* + private class PairwiseDStream(prev:DStream[Array[Byte]]) extends DStream[(Long, Array[Byte])](prev.ssc){ override def dependencies = List(prev) @@ -144,9 +144,9 @@ DStream[(Long, Array[Byte])](prev.ssc){ case None => None } } - val asJavaPairDStream : JavaPairDStream[Long, Array[Byte]] = JavaPairDStream(this) + val asJavaPairDStream : JavaPairDStream[Long, Array[Byte]] = JavaPairDStream.fromJavaDStream(this) } -*/ + From 69e9cd33a58b880f96cc9c3e5e62eaa415c49843 Mon Sep 17 00:00:00 2001 From: Ken Takagiwa Date: Wed, 16 Jul 2014 11:07:42 -0700 Subject: [PATCH 005/347] implementing transform function in Python --- python/pyspark/mllib/_common.py | 2 +- python/pyspark/streaming/dstream.py | 3 +- .../api/python/PythonTransformedDStream.scala | 37 +++++++++++++++++++ .../spark/streaming/dstream/DStream.scala | 3 ++ 4 files changed, 42 insertions(+), 3 deletions(-) create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonTransformedDStream.scala diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index e609b60a0f968..4b723693f43e3 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -164,7 +164,7 @@ def _deserialize_double_vector(ba, offset=0): nb = len(ba) - offset if nb < 5: raise TypeError("_deserialize_double_vector called on a %d-byte array, " - "which is too short" % nb) + "which is too short" % nb) if ba[offset] == DENSE_VECTOR_MAGIC: return _deserialize_dense_vector(ba, offset) elif ba[offset] == SPARSE_VECTOR_MAGIC: diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index e144f8bc1cc09..3365c6d69c1a2 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -172,7 +172,6 @@ def _mergeCombiners(iterator): return combiners.iteritems() return shuffled.mapPartitions(_mergeCombiners) - def partitionBy(self, numPartitions, partitionFunc=None): """ Return a copy of the DStream partitioned using the specified partitioner. 
@@ -231,6 +230,7 @@ def slice(self, fromTime, toTime): def transform(self, transformFunc): """ """ + self._jdstream.transform(transformFunc) raise NotImplementedError def transformWith(self, other, transformFunc): @@ -264,7 +264,6 @@ def _defaultReducePartitions(self): """ # hard code to avoid the error - return 2 if self.ctx._conf.contains("spark.default.parallelism"): return self.ctx.defaultParallelism else: diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonTransformedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonTransformedDStream.scala new file mode 100644 index 0000000000000..ff70483b771a4 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonTransformedDStream.scala @@ -0,0 +1,37 @@ +package org.apache.spark.streaming.api.python + +import org.apache.spark.Accumulator +import org.apache.spark.api.python.PythonRDD +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.api.java.JavaDStream +import org.apache.spark.streaming.{Time, Duration} +import org.apache.spark.streaming.dstream.DStream + +import scala.reflect.ClassTag + +/** + * Created by ken on 7/15/14. + */ +class PythonTransformedDStream[T: ClassTag]( + parents: Seq[DStream[T]], + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + preservePartitoning: Boolean, + pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], + accumulator: Accumulator[JList[Array[Byte]]] + ) extends DStream[Array[Byte]](parent.ssc) { + + override def dependencies = List(parent) + + override def slideDuration: Duration = parent.slideDuration + + //pythonDStream compute + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { + val parentRDDs = parents.map(_.getOrCompute(validTime).orNull).toSeq + Some() + } + val asJavaDStream = JavaDStream.fromDStream(this) +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index d9d5446b62e9f..67977244ef420 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -561,9 +561,12 @@ abstract class DStream[T: ClassTag] ( // because the DStream is reachable from the outer object here, and because // DStreams can't be serialized with closures, we can't proactively check // it for serializability and so we pass the optional false to SparkContext.clean + + // serialized python val cleanedF = context.sparkContext.clean(transformFunc, false) val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { assert(rdds.length == 1) + // if transformfunc is fine, it is okay cleanedF(rdds.head.asInstanceOf[RDD[T]], time) } new TransformedDStream[U](Seq(this), realTransformFunc) From 72bfc66074b2f35224f116759e0a47204a138f24 Mon Sep 17 00:00:00 2001 From: Ken Takagiwa Date: Wed, 16 Jul 2014 11:12:53 -0700 Subject: [PATCH 006/347] modified the code base on comment in https://github.com/tdas/spark/pull/10 --- core/pom.xml | 2 +- python/pyspark/streaming/__init__.py | 1 - python/pyspark/streaming/context.py | 5 +---- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index a59fc9fc035d7..6abf8480d5da0 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.0.0 + 1.1.0-SNAPSHOT ../pom.xml diff --git 
a/python/pyspark/streaming/__init__.py b/python/pyspark/streaming/__init__.py index 719592912e80c..e69de29bb2d1d 100644 --- a/python/pyspark/streaming/__init__.py +++ b/python/pyspark/streaming/__init__.py @@ -1 +0,0 @@ -__author__ = 'ktakagiw' diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index c8ae9c4af85c9..40e9d98942e2e 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -1,6 +1,3 @@ -__author__ = 'ktakagiw' - - # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -41,7 +38,7 @@ class StreamingContext(object): """ - Main entry point for Spark functionality. A StreamingContext represents the + Main entry point for Spark Streaming functionality. A StreamingContext represents the connection to a Spark cluster, and can be used to create L{RDD}s and broadcast variables on that cluster. """ From a7a0b5ce72e9bad14880f2285544d11d725f0f14 Mon Sep 17 00:00:00 2001 From: Ken Takagiwa Date: Wed, 16 Jul 2014 11:17:02 -0700 Subject: [PATCH 007/347] add coment for hack why PYSPARK_PYTHON is needed in spark-submit --- bin/spark-submit | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/bin/spark-submit b/bin/spark-submit index ac275b7696d5c..fa022f707e572 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -37,6 +37,16 @@ done DEPLOY_MODE=${DEPLOY_MODE:-"client"} + +# This is a hack to make DStream.pyprint work. +# This will be removed after pyprint is moved to PythonDStream. +# Problem is that print function is in (Scala)DStream. +# Whenever python code is executed, we call PythonDStream which passes +# pythonExec(which python Spark should execute). +# Since pyprint is located in DStream, Spark does not know which python should use. +# In that case, get python path from PYSPARK_PYTHON, environmental variable. +# This fix is ongoing in print branch in my repo. + # Figure out which Python executable to use if [[ -z "$PYSPARK_PYTHON" ]]; then PYSPARK_PYTHON="python" From 0a516f5a31bfb5f5d3ac58139af820ad8bb50a5a Mon Sep 17 00:00:00 2001 From: Ken Takagiwa Date: Wed, 16 Jul 2014 11:19:13 -0700 Subject: [PATCH 008/347] add coment for hack why PYSPARK_PYTHON is needed in spark-submit --- bin/spark-submit | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/spark-submit b/bin/spark-submit index fa022f707e572..ec4e10787cff0 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -45,7 +45,7 @@ DEPLOY_MODE=${DEPLOY_MODE:-"client"} # pythonExec(which python Spark should execute). # Since pyprint is located in DStream, Spark does not know which python should use. # In that case, get python path from PYSPARK_PYTHON, environmental variable. -# This fix is ongoing in print branch in my repo. +# This fix is ongoing in print branch in https://github.com/giwa/spark/tree/print. 
# Figure out which Python executable to use if [[ -z "$PYSPARK_PYTHON" ]]; then From 57e3e52191464f6b8f8ec53a6452dcf86d4704a6 Mon Sep 17 00:00:00 2001 From: Ken Takagiwa Date: Wed, 16 Jul 2014 11:24:08 -0700 Subject: [PATCH 009/347] remove not implemented DStream functions in python --- python/pyspark/streaming/dstream.py | 102 ---------------------------- 1 file changed, 102 deletions(-) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index a512517f6e437..6ab9c500450aa 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -54,50 +54,6 @@ def pyprint(self): """ self._jdstream.pyprint() - def cache(self): - """ - """ - raise NotImplementedError - - def checkpoint(self): - """ - """ - raise NotImplementedError - - def compute(self, time): - """ - """ - raise NotImplementedError - - def context(self): - """ - """ - raise NotImplementedError - - def count(self): - """ - """ - raise NotImplementedError - - def countByValue(self, numPartitions=None): - """ - """ - raise NotImplementedError - - def countByValueAndWindow(self, duration, slideDuration=None): - """ - """ - raise NotImplementedError - - def countByWindow(self, duration, slideDuration=None): - """ - """ - raise NotImplementedError - - def dstream(self): - """ - """ - raise NotImplementedError def filter(self, f): """ @@ -111,16 +67,6 @@ def flatMap(self, f, preservesPartitioning=False): def func(s, iterator): return chain.from_iterable(imap(f, iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) - def foreachRDD(self, f, time): - """ - """ - raise NotImplementedError - - def glom(self): - """ - """ - raise NotImplementedError - def map(self, f, preservesPartitioning=False): """ """ @@ -133,11 +79,6 @@ def mapPartitions(self, f): def func(s, iterator): return f(iterator) return self.mapPartitionsWithIndex(func) - def perist(self, storageLevel): - """ - """ - raise NotImplementedError - def reduce(self, func, numPartitions=None): """ @@ -210,49 +151,6 @@ def add_shuffle_key(split, iterator): dstream._partitionFunc = partitionFunc return dstream - - - def reduceByWindow(self, reduceFunc, windowDuration, slideDuration, inReduceTunc): - """ - """ - - raise NotImplementedError - - def repartition(self, numPartitions): - """ - """ - raise NotImplementedError - - def slice(self, fromTime, toTime): - """ - """ - raise NotImplementedError - - def transform(self, transformFunc): - """ - """ - raise NotImplementedError - - def transformWith(self, other, transformFunc): - """ - """ - raise NotImplementedError - - def union(self, that): - """ - """ - raise NotImplementedError - - def window(self, windowDuration, slideDuration=None): - """ - """ - raise NotImplementedError - - def wrapRDD(self, rdd): - """ - """ - raise NotImplementedError - def mapPartitionsWithIndex(self, f, preservesPartitioning=False): return PipelinedDStream(self, f, preservesPartitioning) From c9d79dd381ee001eb5920ca865b5dc72f8b46a7f Mon Sep 17 00:00:00 2001 From: Ken Takagiwa Date: Wed, 16 Jul 2014 11:35:59 -0700 Subject: [PATCH 010/347] revert pom.xml --- python/pyspark/streaming/pyprint.py | 2 +- streaming/pom.xml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/streaming/pyprint.py b/python/pyspark/streaming/pyprint.py index fcdaca510812c..6e87c985a57e3 100644 --- a/python/pyspark/streaming/pyprint.py +++ b/python/pyspark/streaming/pyprint.py @@ -1,6 +1,6 @@ import sys from itertools import chain -from pyspark.serializers import 
PickleSerializer, BatchedSerializer, UTF8Deserializer +from pyspark.serializers import PickleSerializer def collect(binary_file_path): dse = PickleSerializer() diff --git a/streaming/pom.xml b/streaming/pom.xml index 88df63592efee..2239ad9c8579c 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.0.0 + 1.1.0-SNAPSHOT ../pom.xml From 8f8202b5c9bfccfb42f7027e7e8079b4b5807f02 Mon Sep 17 00:00:00 2001 From: Ken Takagiwa Date: Wed, 16 Jul 2014 11:38:26 -0700 Subject: [PATCH 011/347] revert streaming pom.xml --- streaming/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/pom.xml b/streaming/pom.xml index 2239ad9c8579c..03102c5e836bf 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -76,7 +76,7 @@ are necessary - first one for 'mvn package', second one for 'mvn compile'. Ideally, 'mvn compile' should not compile test classes and therefore should not need this. However, an open Maven bug (http://jira.codehaus.org/browse/MNG-3559) - causes the compilation to fail if streaming test-jar is not generated. Hence, the + causes the compilation to fail if streaming test-jar is not generated. Hence, the second execution profile for 'mvn compile'. --> From fa4a7fc1b0643bfbe48b24e3897d65bce3332e64 Mon Sep 17 00:00:00 2001 From: Ken Takagiwa Date: Wed, 16 Jul 2014 11:44:14 -0700 Subject: [PATCH 012/347] revert streaming/pom.xml --- streaming/pom.xml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/streaming/pom.xml b/streaming/pom.xml index 03102c5e836bf..f506d6ce34a6f 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -69,12 +69,12 @@ org.scalatest scalatest-maven-plugin - - Time + // |_____________________________| + // + // |________ _________| |________ _________| + // | | + // V V + // old RDDs new RDDs + // + + + // Get the RDD of the reduced value of the previous window + val previousWindowRDD = + getOrCompute(previousWindow.endTime) + + if (windowDuration > slideDuration * 5 && previousWindowRDD.isDefined) { + // subtle the values from old RDDs + val oldRDDs = + parent.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideDuration) + val subbed = if (oldRDDs.size > 0) { + invReduceFunc.call(JavaRDD.fromRDD(previousWindowRDD.get), + JavaRDD.fromRDD(ssc.sc.union(oldRDDs)), validTime.milliseconds).rdd + } else { + previousWindowRDD.get + } + + // add the RDDs of the reduced values in "new time steps" + val newRDDs = + parent.slice(previousWindow.endTime, currentWindow.endTime - parent.slideDuration) + + if (newRDDs.size > 0) { + Some(reduceFunc.call(JavaRDD.fromRDD(ssc.sc.union(newRDDs).union(subbed)), validTime.milliseconds)) + } else { + Some(subbed) + } + } else { + // Get the RDDs of the reduced values in current window + val currentRDDs = + parent.slice(currentWindow.beginTime, currentWindow.endTime - parent.slideDuration) + if (currentRDDs.size > 0) { + Some(reduceFunc.call(JavaRDD.fromRDD(ssc.sc.union(currentRDDs)), validTime.milliseconds)) + } else { + None + } + } + } + + val asJavaDStream = JavaDStream.fromDStream(this) +} + + /** * This is used for foreachRDD() in Python */ From c28f520ec2e77c6a5f7139b5131182024eddd1be Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 26 Sep 2014 13:56:50 -0700 Subject: [PATCH 304/347] support updateStateByKey --- python/pyspark/streaming/dstream.py | 30 +++++++++---- python/pyspark/streaming/tests.py | 19 ++++++++ python/pyspark/streaming/util.py | 11 ++--- .../streaming/api/python/PythonDStream.scala | 44 
++++++++++++++++--- 4 files changed, 83 insertions(+), 21 deletions(-) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 38bb54f25eaa2..27e1400b8ba0b 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -366,8 +366,9 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration, numPartitions=None): reduced = self.reduceByKey(func) - def reduceFunc(a, t): - return a.reduceByKey(func, numPartitions) + def reduceFunc(a, b, t): + b = b.reduceByKey(func, numPartitions) + return a.union(b).reduceByKey(func, numPartitions) if a else b def invReduceFunc(a, b, t): b = b.reduceByKey(func, numPartitions) @@ -378,19 +379,30 @@ def invReduceFunc(a, b, t): windowDuration = Seconds(windowDuration) if not isinstance(slideDuration, Duration): slideDuration = Seconds(slideDuration) - serializer = reduced._jrdd_deserializer - jreduceFunc = RDDFunction(self.ctx, reduceFunc, reduced._jrdd_deserializer) + jreduceFunc = RDDFunction2(self.ctx, reduceFunc, reduced._jrdd_deserializer) jinvReduceFunc = RDDFunction2(self.ctx, invReduceFunc, reduced._jrdd_deserializer) dstream = self.ctx._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(), jreduceFunc, jinvReduceFunc, windowDuration._jduration, slideDuration._jduration) - return DStream(dstream.asJavaDStream(), self._ssc, serializer) + return DStream(dstream.asJavaDStream(), self._ssc, self.ctx.serializer) + + def updateStateByKey(self, updateFunc, numPartitions=None): + """ + :param updateFunc: [(k, vs, s)] -> [(k, s)] + """ + def reduceFunc(a, b, t): + if a is None: + g = b.groupByKey(numPartitions).map(lambda (k, vs): (k, list(vs), None)) + else: + g = a.cogroup(b).map(lambda (k, (va, vb)): + (k, list(vb), list(va)[0] if len(va) else None)) + return g.mapPartitions(lambda x: updateFunc(x) or []) - def updateStateByKey(self, updateFunc): - # FIXME: convert updateFunc to java JFunction2 - jFunc = updateFunc - return self._jdstream.updateStateByKey(jFunc) + jreduceFunc = RDDFunction2(self.ctx, reduceFunc, + self.ctx.serializer, self._jrdd_deserializer) + dstream = self.ctx._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc) + return DStream(dstream.asJavaDStream(), self._ssc, self.ctx.serializer) class TransformedDStream(DStream): diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index aa20b7efbee46..755ea224e56da 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -294,6 +294,25 @@ def func(dstream): [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]] self._test_func(input, func, expected) + def update_state_by_key(self): + + def updater(it): + for k, vs, s in it: + if not s: + s = vs + else: + s.extend(vs) + yield (k, s) + + input = [[('k', i)] for i in range(5)] + + def func(dstream): + return dstream.updateStateByKey(updater) + + expected = [[0], [0, 1], [0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]] + expected = [[('k', v)] for v in expected] + self._test_func(input, func, expected) + class TestStreamingContext(unittest.TestCase): def setUp(self): diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index 4051732f25302..fdbd01ec1766d 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -50,15 +50,16 @@ class RDDFunction2(object): This class is for py4j callback. This class is related with org.apache.spark.streaming.api.python.PythonRDDFunction2. 
""" - def __init__(self, ctx, func, jrdd_deserializer): + def __init__(self, ctx, func, jrdd_deserializer, jrdd_deserializer2=None): self.ctx = ctx self.func = func - self.deserializer = jrdd_deserializer + self.jrdd_deserializer = jrdd_deserializer + self.jrdd_deserializer2 = jrdd_deserializer2 or jrdd_deserializer def call(self, jrdd, jrdd2, milliseconds): try: - rdd = RDD(jrdd, self.ctx, self.deserializer) if jrdd else None - other = RDD(jrdd2, self.ctx, self.deserializer) if jrdd2 else None + rdd = RDD(jrdd, self.ctx, self.jrdd_deserializer) if jrdd else None + other = RDD(jrdd2, self.ctx, self.jrdd_deserializer2) if jrdd2 else None r = self.func(rdd, other, milliseconds) if r: return r._jrdd @@ -67,7 +68,7 @@ def call(self, jrdd, jrdd2, milliseconds): traceback.print_exc() def __repr__(self): - return "RDDFunction(%s, %s)" % (str(self.deserializer), str(self.func)) + return "RDDFunction2(%s)" % (str(self.func)) class Java: implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction2'] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 689c04fa49135..b904e273eb438 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -118,7 +118,7 @@ private[spark] class PythonTransformed2DStream (parent: DStream[_], parent2: DSt private[spark] class PythonReducedWindowedDStream( parent: DStream[Array[Byte]], - reduceFunc: PythonRDDFunction, + reduceFunc: PythonRDDFunction2, invReduceFunc: PythonRDDFunction2, _windowDuration: Duration, _slideDuration: Duration @@ -149,10 +149,6 @@ class PythonReducedWindowedDStream( override def parentRememberDuration: Duration = rememberDuration + windowDuration override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { - None - val reduceF = reduceFunc - val invReduceF = invReduceFunc - val currentTime = validTime val currentWindow = new Interval(currentTime - windowDuration + parent.slideDuration, currentTime) @@ -196,7 +192,7 @@ class PythonReducedWindowedDStream( parent.slice(previousWindow.endTime, currentWindow.endTime - parent.slideDuration) if (newRDDs.size > 0) { - Some(reduceFunc.call(JavaRDD.fromRDD(ssc.sc.union(newRDDs).union(subbed)), validTime.milliseconds)) + Some(reduceFunc.call(JavaRDD.fromRDD(subbed), JavaRDD.fromRDD(ssc.sc.union(newRDDs)), validTime.milliseconds)) } else { Some(subbed) } @@ -205,7 +201,7 @@ class PythonReducedWindowedDStream( val currentRDDs = parent.slice(currentWindow.beginTime, currentWindow.endTime - parent.slideDuration) if (currentRDDs.size > 0) { - Some(reduceFunc.call(JavaRDD.fromRDD(ssc.sc.union(currentRDDs)), validTime.milliseconds)) + Some(reduceFunc.call(null, JavaRDD.fromRDD(ssc.sc.union(currentRDDs)), validTime.milliseconds)) } else { None } @@ -216,6 +212,40 @@ class PythonReducedWindowedDStream( } +/** + * Copied from ReducedWindowedDStream + */ +private[spark] +class PythonStateDStream( + parent: DStream[Array[Byte]], + reduceFunc: PythonRDDFunction2 + ) extends DStream[Array[Byte]](parent.ssc) { + + super.persist(StorageLevel.MEMORY_ONLY) + + override def dependencies = List(parent) + + override def slideDuration: Duration = parent.slideDuration + + override val mustCheckpoint = true + + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { + val lastState = getOrCompute(validTime - slideDuration) + val newRDD = 
parent.getOrCompute(validTime) + if (newRDD.isDefined) { + if (lastState.isDefined) { + Some(reduceFunc.call(JavaRDD.fromRDD(lastState.get), JavaRDD.fromRDD(newRDD.get), validTime.milliseconds)) + } else { + Some(reduceFunc.call(null, JavaRDD.fromRDD(newRDD.get), validTime.milliseconds)) + } + } else { + lastState + } + } + + val asJavaDStream = JavaDStream.fromDStream(this) +} + /** * This is used for foreachRDD() in Python */ From 3f0fb4b7e8265c9076077bc8290aeac3b9aeb18b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 27 Sep 2014 00:15:52 -0700 Subject: [PATCH 305/347] refactor fix tests --- python/pyspark/serializers.py | 3 + python/pyspark/streaming/context.py | 129 +++++++++-- python/pyspark/streaming/dstream.py | 8 +- python/pyspark/streaming/tests.py | 62 +++-- python/pyspark/streaming/util.py | 13 +- .../streaming/api/python/PythonDStream.scala | 219 ++++++++++-------- 6 files changed, 288 insertions(+), 146 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 2672da36c1f50..94bebc310bad6 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -114,6 +114,9 @@ def __ne__(self, other): def __repr__(self): return "<%s object>" % self.__class__.__name__ + def __hash__(self): + return hash(str(self)) + class FramedSerializer(Serializer): diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 1c7cb5604e5cc..c4a1014ab9ab0 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -15,16 +15,51 @@ # limitations under the License. # -from pyspark.serializers import UTF8Deserializer +from pyspark import RDD +from pyspark.serializers import UTF8Deserializer, BatchedSerializer from pyspark.context import SparkContext +from pyspark.storagelevel import StorageLevel from pyspark.streaming.dstream import DStream -from pyspark.streaming.duration import Duration, Seconds +from pyspark.streaming.duration import Seconds from py4j.java_collections import ListConverter __all__ = ["StreamingContext"] +def _daemonize_callback_server(): + """ + Hack Py4J to daemonize callback server + """ + # TODO: create a patch for Py4J + import socket + import py4j.java_gateway + logger = py4j.java_gateway.logger + from py4j.java_gateway import Py4JNetworkError + from threading import Thread + + def start(self): + """Starts the CallbackServer. This method should be called by the + client instead of run().""" + self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, + 1) + try: + self.server_socket.bind((self.address, self.port)) + # self.port = self.server_socket.getsockname()[1] + except Exception: + msg = 'An error occurred while trying to start the callback server' + logger.exception(msg) + raise Py4JNetworkError(msg) + + # Maybe thread needs to be cleanup up? + self.thread = Thread(target=self.run) + self.thread.daemon = True + self.thread.start() + + py4j.java_gateway.CallbackServer.start = start + + class StreamingContext(object): """ Main entry point for Spark Streaming functionality. 
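To make the updateStateByKey contract introduced above concrete, a minimal sketch (illustrative names only; pairs stands for any DStream of key/value pairs built elsewhere, and a checkpoint directory is assumed to be set, as the tests do with setCheckpointDir). The update function receives an iterator of (key, new_values, old_state) tuples, with old_state being None the first time a key appears, and yields (key, new_state) pairs:

def running_total(partition):
    # one (key, new_values, old_state) tuple per key in this partition
    for key, new_values, old_total in partition:
        yield (key, (old_total or 0) + sum(new_values))

totals = pairs.updateStateByKey(running_total)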
A StreamingContext represents the @@ -53,7 +88,9 @@ def _start_callback_server(self): gw = self._sc._gateway # getattr will fallback to JVM if "_callback_server" not in gw.__dict__: + _daemonize_callback_server() gw._start_callback_server(gw._python_proxy_port) + gw._python_proxy_port = gw._callback_server.port # update port with real port def _initialize_context(self, sc, duration): return self._jvm.JavaStreamingContext(sc._jsc, duration._jduration) @@ -92,26 +129,44 @@ def stop(self, stopSparkContext=True, stopGraceFully=False): def remember(self, duration): """ - Set each DStreams in this context to remember RDDs it generated in the last given duration. - DStreams remember RDDs only for a limited duration of time and releases them for garbage - collection. This method allows the developer to specify how to long to remember the RDDs ( - if the developer wishes to query old data outside the DStream computation). - @param duration pyspark.streaming.duration.Duration object or seconds. - Minimum duration that each DStream should remember its RDDs + Set each DStreams in this context to remember RDDs it generated + in the last given duration. DStreams remember RDDs only for a + limited duration of time and releases them for garbage collection. + This method allows the developer to specify how to long to remember + the RDDs ( if the developer wishes to query old data outside the + DStream computation). + + @param duration Minimum duration (in seconds) that each DStream + should remember its RDDs """ if isinstance(duration, (int, long, float)): duration = Seconds(duration) self._jssc.remember(duration._jduration) - # TODO: add storageLevel - def socketTextStream(self, hostname, port): + def checkpoint(self, directory): + """ + Sets the context to periodically checkpoint the DStream operations for master + fault-tolerance. The graph will be checkpointed every batch interval. + + @param directory HDFS-compatible directory where the checkpoint data + will be reliably stored + """ + self._jssc.checkpoint(directory) + + def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): """ Create an input from TCP source hostname:port. Data is received using a TCP socket and receive byte is interpreted as UTF8 encoded '\n' delimited lines. + + @param hostname Hostname to connect to for receiving data + @param port Port to connect to for receiving data + @param storageLevel Storage level to use for storing the received objects """ - return DStream(self._jssc.socketTextStream(hostname, port), self, UTF8Deserializer()) + jlevel = self._sc._getJavaStorageLevel(storageLevel) + return DStream(self._jssc.socketTextStream(hostname, port, jlevel), self, + UTF8Deserializer()) def textFileStream(self, directory): """ @@ -122,14 +177,52 @@ def textFileStream(self, directory): """ return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer()) - def _makeStream(self, inputs, numSlices=None): + def _check_serialzers(self, rdds): + # make sure they have same serializer + if len(set(rdd._jrdd_deserializer for rdd in rdds)): + for i in range(len(rdds)): + # reset them to sc.serializer + rdds[i] = rdds[i].map(lambda x: x, preservesPartitioning=True) + + def queueStream(self, queue, oneAtATime=False, default=None): """ - This function is only for unittest. - It requires a list as input, and returns the i_th element at the i_th batch - under manual clock. + Create an input stream from an queue of RDDs or list. 
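A usage sketch of this input source, mirroring how the unit tests below drive deterministic batches (ssc is an already constructed StreamingContext and the timing is illustrative):

import time

batches = [[1, 2, 3], [4, 5], [6]]     # one list per batch interval
stream = ssc.queueStream(batches)      # each list is turned into an RDD
result = stream.collect()              # collect() appends each batch's output
ssc.start()
time.sleep(2)                          # let a few batch intervals pass
# result should now hold roughly [[1, 2, 3], [4, 5], [6]]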
In each batch, + it will process either one or all of the RDDs returned by the queue. + + NOTE: changes to the queue after the stream is created will not be recognized. + @param queue Queue of RDDs + @tparam T Type of objects in the RDD """ - rdds = [self._sc.parallelize(input, numSlices) for input in inputs] + if queue and not isinstance(queue[0], RDD): + rdds = [self._sc.parallelize(input) for input in queue] + else: + rdds = queue + self._check_serialzers(rdds) jrdds = ListConverter().convert([r._jrdd for r in rdds], SparkContext._gateway._gateway_client) - jdstream = self._jvm.PythonDataInputStream(self._jssc, jrdds).asJavaDStream() - return DStream(jdstream, self, rdds[0]._jrdd_deserializer) + jdstream = self._jvm.PythonDataInputStream(self._jssc, jrdds, oneAtATime, + default and default._jrdd) + return DStream(jdstream.asJavaDStream(), self, rdds[0]._jrdd_deserializer) + + def transform(self, dstreams, transformFunc): + """ + Create a new DStream in which each RDD is generated by applying a function on RDDs of + the DStreams. The order of the JavaRDDs in the transform function parameter will be the + same as the order of corresponding DStreams in the list. + """ + # TODO + + def union(self, *dstreams): + """ + Create a unified DStream from multiple DStreams of the same + type and same slide duration. + """ + if not dstreams: + raise ValueError("should have at least one DStream to union") + if len(dstreams) == 1: + return dstreams[0] + self._check_serialzers(dstreams) + first = dstreams[0] + jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]], + SparkContext._gateway._gateway_client) + return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 27e1400b8ba0b..9dd3556327477 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -315,16 +315,16 @@ def repartitions(self, numPartitions): return self.transform(lambda rdd: rdd.repartition(numPartitions)) def union(self, other): - return self.transformWith(lambda a, b: a.union(b), other, True) + return self.transformWith(lambda a, b, t: a.union(b), other, True) def cogroup(self, other): - return self.transformWith(lambda a, b: a.cogroup(b), other) + return self.transformWith(lambda a, b, t: a.cogroup(b), other) def leftOuterJoin(self, other): - return self.transformWith(lambda a, b: a.leftOuterJion(b), other) + return self.transformWith(lambda a, b, t: a.leftOuterJion(b), other) def rightOuterJoin(self, other): - return self.transformWith(lambda a, b: a.rightOuterJoin(b), other) + return self.transformWith(lambda a, b, t: a.rightOuterJoin(b), other) def _jtime(self, milliseconds): return self.ctx._jvm.Time(milliseconds) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 755ea224e56da..a585bbfa06f5b 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -40,27 +40,25 @@ def setUp(self): class_name = self.__class__.__name__ self.sc = SparkContext(appName=class_name) self.sc.setCheckpointDir("/tmp") + # TODO: decrease duration to speed up tests self.ssc = StreamingContext(self.sc, duration=Seconds(1)) def tearDown(self): self.ssc.stop() - self.sc.stop() @classmethod def tearDownClass(cls): # Make sure tp shutdown the callback server SparkContext._gateway._shutdown_callback_server() - def _test_func(self, input, func, expected, numSlices=None, sort=False): + def _test_func(self, input, func, 
expected, sort=False): """ - Start stream and return the result. @param input: dataset for the test. This should be list of lists. @param func: wrapped function. This function should return PythonDStream object. @param expected: expected output for this testcase. - @param numSlices: the number of slices in the rdd in the dstream. """ # Generate input stream with user-defined input. - input_stream = self.ssc._makeStream(input, numSlices) + input_stream = self.ssc.queueStream(input) # Apply test function to stream. stream = func(input_stream) result = stream.collect() @@ -121,7 +119,7 @@ def func(dstream): def test_count(self): """Basic operation test for DStream.count.""" - input = [range(1, 5), range(1, 10), range(1, 20)] + input = [range(5), range(10), range(20)] def func(dstream): return dstream.count() @@ -178,24 +176,24 @@ def func(dstream): def test_glom(self): """Basic operation test for DStream.glom.""" input = [range(1, 5), range(5, 9), range(9, 13)] - numSlices = 2 + rdds = [self.sc.parallelize(r, 2) for r in input] def func(dstream): return dstream.glom() expected = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]] - self._test_func(input, func, expected, numSlices) + self._test_func(rdds, func, expected) def test_mapPartitions(self): """Basic operation test for DStream.mapPartitions.""" input = [range(1, 5), range(5, 9), range(9, 13)] - numSlices = 2 + rdds = [self.sc.parallelize(r, 2) for r in input] def func(dstream): def f(iterator): yield sum(iterator) return dstream.mapPartitions(f) expected = [[3, 7], [11, 15], [19, 23]] - self._test_func(input, func, expected, numSlices) + self._test_func(rdds, func, expected) def test_countByValue(self): """Basic operation test for DStream.countByValue.""" @@ -236,14 +234,14 @@ def add(a, b): self._test_func(input, func, expected, sort=True) def test_union(self): - input1 = [range(3), range(5), range(1)] + input1 = [range(3), range(5), range(1), range(6)] input2 = [range(3, 6), range(5, 6), range(1, 6)] - d1 = self.ssc._makeStream(input1) - d2 = self.ssc._makeStream(input2) + d1 = self.ssc.queueStream(input1) + d2 = self.ssc.queueStream(input2) d = d1.union(d2) result = d.collect() - expected = [range(6), range(6), range(6)] + expected = [range(6), range(6), range(6), range(6)] self.ssc.start() start_time = time.time() @@ -317,33 +315,49 @@ def func(dstream): class TestStreamingContext(unittest.TestCase): def setUp(self): self.sc = SparkContext(master="local[2]", appName=self.__class__.__name__) - self.batachDuration = Seconds(1) - self.ssc = None + self.batachDuration = Seconds(0.1) + self.ssc = StreamingContext(self.sc, self.batachDuration) def tearDown(self): - if self.ssc is not None: - self.ssc.stop() + self.ssc.stop() self.sc.stop() def test_stop_only_streaming_context(self): - self.ssc = StreamingContext(self.sc, self.batachDuration) - self._addInputStream(self.ssc) + self._addInputStream() self.ssc.start() self.ssc.stop(False) self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5) def test_stop_multiple_times(self): - self.ssc = StreamingContext(self.sc, self.batachDuration) - self._addInputStream(self.ssc) + self._addInputStream() self.ssc.start() self.ssc.stop() self.ssc.stop() - def _addInputStream(self, s): + def _addInputStream(self): # Make sure each length of input is over 3 inputs = map(lambda x: range(1, x), range(5, 101)) - stream = s._makeStream(inputs) + stream = self.ssc.queueStream(inputs) stream.collect() + def test_queueStream(self): + input = [range(i) for i in range(3)] + dstream = 
self.ssc.queueStream(input) + result = dstream.collect() + self.ssc.start() + time.sleep(1) + self.assertEqual(input, result) + + def test_union(self): + input = [range(i) for i in range(3)] + dstream = self.ssc.queueStream(input) + dstream2 = self.ssc.union(dstream, dstream) + result = dstream.collect() + self.ssc.start() + time.sleep(1) + expected = [i * 2 for i in input] + self.assertEqual(input, result) + + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index fdbd01ec1766d..feff1b3889c49 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -30,7 +30,10 @@ def __init__(self, ctx, func, jrdd_deserializer): def call(self, jrdd, milliseconds): try: - rdd = RDD(jrdd, self.ctx, self.deserializer) + emptyRDD = getattr(self.ctx, "_emptyRDD", None) + if emptyRDD is None: + self.ctx._emptyRDD = emptyRDD = self.ctx.parallelize([]).cache() + rdd = RDD(jrdd, self.ctx, self.deserializer) if jrdd else emptyRDD r = self.func(rdd, milliseconds) if r: return r._jrdd @@ -58,8 +61,12 @@ def __init__(self, ctx, func, jrdd_deserializer, jrdd_deserializer2=None): def call(self, jrdd, jrdd2, milliseconds): try: - rdd = RDD(jrdd, self.ctx, self.jrdd_deserializer) if jrdd else None - other = RDD(jrdd2, self.ctx, self.jrdd_deserializer2) if jrdd2 else None + emptyRDD = getattr(self.ctx, "_emptyRDD", None) + if emptyRDD is None: + self.ctx._emptyRDD = emptyRDD = self.ctx.parallelize([]).cache() + + rdd = RDD(jrdd, self.ctx, self.jrdd_deserializer) if jrdd else emptyRDD + other = RDD(jrdd2, self.ctx, self.jrdd_deserializer2) if jrdd2 else emptyRDD r = self.func(rdd, other, milliseconds) if r: return r._jrdd diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index b904e273eb438..828a620e4c08f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -39,6 +39,22 @@ trait PythonRDDFunction { def call(rdd: JavaRDD[_], time: Long): JavaRDD[Array[Byte]] } +class RDDFunction(pfunc: PythonRDDFunction) { + def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { + val jrdd = if (rdd.isDefined) { + JavaRDD.fromRDD(rdd.get) + } else { + null + } + val r = pfunc.call(jrdd, time.milliseconds) + if (r != null) { + Some(r.rdd) + } else { + None + } + } +} + /** * Interface for Python callback function with three arguments */ @@ -46,33 +62,61 @@ trait PythonRDDFunction2 { def call(rdd: JavaRDD[_], rdd2: JavaRDD[_], time: Long): JavaRDD[Array[Byte]] } +class RDDFunction2(pfunc: PythonRDDFunction2) { + def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { + val jrdd = if (rdd.isDefined) { + JavaRDD.fromRDD(rdd.get) + } else { + null + } + val jrdd2 = if (rdd2.isDefined) { + JavaRDD.fromRDD(rdd2.get) + } else { + null + } + val r = pfunc.call(jrdd, jrdd2, time.milliseconds) + if (r != null) { + Some(r.rdd) + } else { + None + } + } +} + +private[python] +abstract class PythonDStream(parent: DStream[_]) extends DStream[Array[Byte]] (parent.ssc) { + + override def dependencies = List(parent) + + override def slideDuration: Duration = parent.slideDuration + + val asJavaDStream = JavaDStream.fromDStream(this) +} + /** * Transformed DStream in Python. 
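On the Python side these wrappers back DStream.transform and the two-stream transformWith used by union, cogroup and the joins above; a sketch of the call pattern (words, pairs1 and pairs2 are placeholder DStreams):

# arbitrary RDD-to-RDD code, run once per batch
distinct_words = words.transform(lambda rdd: rdd.distinct())
# the two-stream variant's function also receives the batch time
joined = pairs1.transformWith(lambda a, b, t: a.join(b), pairs2)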
* * If the result RDD is PythonRDD, then it will cache it as an template for future use, * this can reduce the Python callbacks. */ -private[spark] class PythonTransformedDStream (parent: DStream[_], func: PythonRDDFunction, +private[spark] class PythonTransformedDStream (parent: DStream[_], pfunc: PythonRDDFunction, var reuse: Boolean = false) - extends DStream[Array[Byte]] (parent.ssc) { + extends PythonDStream(parent) { + val func = new RDDFunction(pfunc) var lastResult: PythonRDD = _ - override def dependencies = List(parent) - - override def slideDuration: Duration = parent.slideDuration - override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { - val rdd1 = parent.getOrCompute(validTime).getOrElse(null) - if (rdd1 == null) { + val rdd1 = parent.getOrCompute(validTime) + if (rdd1.isEmpty) { return None } if (reuse && lastResult != null) { - Some(lastResult.copyTo(rdd1)) + Some(lastResult.copyTo(rdd1.get)) } else { - val r = func.call(JavaRDD.fromRDD(rdd1), validTime.milliseconds).rdd - if (reuse && lastResult == null) { - r match { + val r = func(rdd1, validTime) + if (reuse && r.isDefined && lastResult == null) { + r.get match { case rdd: PythonRDD => if (rdd.parent(0) == rdd1) { // only one PythonRDD @@ -83,46 +127,65 @@ private[spark] class PythonTransformedDStream (parent: DStream[_], func: PythonR } } } - Some(r) + r } } - - val asJavaDStream = JavaDStream.fromDStream(this) } /** * Transformed from two DStreams in Python. */ -private[spark] class PythonTransformed2DStream (parent: DStream[_], parent2: DStream[_], func: PythonRDDFunction2) +private[spark] +class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_], + pfunc: PythonRDDFunction2) extends DStream[Array[Byte]] (parent.ssc) { - override def dependencies = List(parent, parent2) + val func = new RDDFunction2(pfunc) override def slideDuration: Duration = parent.slideDuration + override def dependencies = List(parent, parent2) + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { - def resultRdd(stream: DStream[_]): JavaRDD[_] = stream.getOrCompute(validTime) match { - case Some(rdd) => JavaRDD.fromRDD(rdd) - case None => null - } - Some(func.call(resultRdd(parent), resultRdd(parent2), validTime.milliseconds)) + func(parent.getOrCompute(validTime), parent2.getOrCompute(validTime), validTime) } val asJavaDStream = JavaDStream.fromDStream(this) } +/** + * similar to StateDStream + */ +private[spark] +class PythonStateDStream(parent: DStream[Array[Byte]], preduceFunc: PythonRDDFunction2) + extends PythonDStream(parent) { + + val reduceFunc = new RDDFunction2(preduceFunc) + + super.persist(StorageLevel.MEMORY_ONLY) + override val mustCheckpoint = true + + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { + val lastState = getOrCompute(validTime - slideDuration) + val rdd = parent.getOrCompute(validTime) + if (rdd.isDefined) { + reduceFunc(lastState, rdd, validTime) + } else { + lastState + } + } +} /** * Copied from ReducedWindowedDStream */ private[spark] -class PythonReducedWindowedDStream( - parent: DStream[Array[Byte]], - reduceFunc: PythonRDDFunction2, - invReduceFunc: PythonRDDFunction2, - _windowDuration: Duration, - _slideDuration: Duration - ) extends DStream[Array[Byte]](parent.ssc) { +class PythonReducedWindowedDStream(parent: DStream[Array[Byte]], + preduceFunc: PythonRDDFunction2, + pinvReduceFunc: PythonRDDFunction2, + _windowDuration: Duration, + _slideDuration: Duration + ) extends PythonStateDStream(parent, preduceFunc) { 
assert(_windowDuration.isMultipleOf(parent.slideDuration), "The window duration of ReducedWindowedDStream (" + _windowDuration + ") " + @@ -134,18 +197,10 @@ class PythonReducedWindowedDStream( "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")" ) + val invReduceFunc = new RDDFunction2(pinvReduceFunc) - // Persist RDDs to memory by default as these RDDs are going to be reused. - super.persist(StorageLevel.MEMORY_ONLY) - - def windowDuration: Duration = _windowDuration - - override def dependencies = List(parent) - + def windowDuration: Duration = _windowDuration override def slideDuration: Duration = _slideDuration - - override val mustCheckpoint = true - override def parentRememberDuration: Duration = rememberDuration + windowDuration override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { @@ -171,20 +226,17 @@ class PythonReducedWindowedDStream( // old RDDs new RDDs // - // Get the RDD of the reduced value of the previous window - val previousWindowRDD = - getOrCompute(previousWindow.endTime) + val previousWindowRDD = getOrCompute(previousWindow.endTime) if (windowDuration > slideDuration * 5 && previousWindowRDD.isDefined) { // subtle the values from old RDDs val oldRDDs = parent.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideDuration) val subbed = if (oldRDDs.size > 0) { - invReduceFunc.call(JavaRDD.fromRDD(previousWindowRDD.get), - JavaRDD.fromRDD(ssc.sc.union(oldRDDs)), validTime.milliseconds).rdd + invReduceFunc(previousWindowRDD, Some(ssc.sc.union(oldRDDs)), validTime) } else { - previousWindowRDD.get + previousWindowRDD } // add the RDDs of the reduced values in "new time steps" @@ -192,58 +244,21 @@ class PythonReducedWindowedDStream( parent.slice(previousWindow.endTime, currentWindow.endTime - parent.slideDuration) if (newRDDs.size > 0) { - Some(reduceFunc.call(JavaRDD.fromRDD(subbed), JavaRDD.fromRDD(ssc.sc.union(newRDDs)), validTime.milliseconds)) + reduceFunc(subbed, Some(ssc.sc.union(newRDDs)), validTime) } else { - Some(subbed) + subbed } } else { // Get the RDDs of the reduced values in current window val currentRDDs = parent.slice(currentWindow.beginTime, currentWindow.endTime - parent.slideDuration) if (currentRDDs.size > 0) { - Some(reduceFunc.call(null, JavaRDD.fromRDD(ssc.sc.union(currentRDDs)), validTime.milliseconds)) + reduceFunc(None, Some(ssc.sc.union(currentRDDs)), validTime) } else { None } } } - - val asJavaDStream = JavaDStream.fromDStream(this) -} - - -/** - * Copied from ReducedWindowedDStream - */ -private[spark] -class PythonStateDStream( - parent: DStream[Array[Byte]], - reduceFunc: PythonRDDFunction2 - ) extends DStream[Array[Byte]](parent.ssc) { - - super.persist(StorageLevel.MEMORY_ONLY) - - override def dependencies = List(parent) - - override def slideDuration: Duration = parent.slideDuration - - override val mustCheckpoint = true - - override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { - val lastState = getOrCompute(validTime - slideDuration) - val newRDD = parent.getOrCompute(validTime) - if (newRDD.isDefined) { - if (lastState.isDefined) { - Some(reduceFunc.call(JavaRDD.fromRDD(lastState.get), JavaRDD.fromRDD(newRDD.get), validTime.milliseconds)) - } else { - Some(reduceFunc.call(null, JavaRDD.fromRDD(newRDD.get), validTime.milliseconds)) - } - } else { - lastState - } - } - - val asJavaDStream = JavaDStream.fromDStream(this) } /** @@ -255,7 +270,9 @@ class PythonForeachDStream( ) extends ForEachDStream[Array[Byte]]( prev, (rdd: RDD[Array[Byte]], time: Time) 
=> { - foreachFunction.call(rdd.toJavaRDD(), time.milliseconds) + if (rdd != null) { + foreachFunction.call(rdd, time.milliseconds) + } } ) { @@ -264,34 +281,42 @@ class PythonForeachDStream( /** - * This is a input stream just for the unitest. This is equivalent to a checkpointable, - * replayable, reliable message queue like Kafka. It requires a JArrayList of JavaRDD, - * and returns the i_th element at the i_th batch under manual clock. + * similar to QueueInputStream */ class PythonDataInputStream( ssc_ : JavaStreamingContext, - inputRDDs: JArrayList[JavaRDD[Array[Byte]]] + inputRDDs: JArrayList[JavaRDD[Array[Byte]]], + oneAtAtime: Boolean, + defaultRDD: JavaRDD[Array[Byte]] ) extends InputDStream[Array[Byte]](JavaStreamingContext.toStreamingContext(ssc_)) { + val emptyRDD = if (defaultRDD != null) { + Some(defaultRDD.rdd) + } else { + None // ssc.sparkContext.emptyRDD[Array[Byte]] + } + def start() {} def stop() {} def compute(validTime: Time): Option[RDD[Array[Byte]]] = { - val emptyRDD = ssc.sparkContext.emptyRDD[Array[Byte]] val index = ((validTime - zeroTime) / slideDuration - 1).toInt - val selectedRDD = { - if (inputRDDs.isEmpty) { + if (oneAtAtime) { + if (index == 0) { + val rdds = inputRDDs.toArray.map(_.asInstanceOf[JavaRDD[Array[Byte]]].rdd).toSeq + Some(ssc.sparkContext.union(rdds)) + } else { emptyRDD - } else if (index < inputRDDs.size()) { - inputRDDs.get(index).rdd + } + } else { + if (index < inputRDDs.size()) { + Some(inputRDDs.get(index).rdd) } else { emptyRDD } } - - Some(selectedRDD) } val asJavaDStream = JavaDStream.fromDStream(this) From c499ba0e48c10b5aa587e81c179f02c1b88e2045 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 27 Sep 2014 00:26:12 -0700 Subject: [PATCH 306/347] remove Time and Duration --- python/pyspark/streaming/context.py | 20 +- python/pyspark/streaming/dstream.py | 13 +- python/pyspark/streaming/duration.py | 401 --------------------------- python/pyspark/streaming/jtime.py | 135 --------- python/pyspark/streaming/tests.py | 4 +- 5 files changed, 14 insertions(+), 559 deletions(-) delete mode 100644 python/pyspark/streaming/duration.py delete mode 100644 python/pyspark/streaming/jtime.py diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index c4a1014ab9ab0..88e0cbbede1be 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -16,11 +16,10 @@ # from pyspark import RDD -from pyspark.serializers import UTF8Deserializer, BatchedSerializer +from pyspark.serializers import UTF8Deserializer from pyspark.context import SparkContext from pyspark.storagelevel import StorageLevel from pyspark.streaming.dstream import DStream -from pyspark.streaming.duration import Seconds from py4j.java_collections import ListConverter @@ -76,9 +75,6 @@ def __init__(self, sparkContext, duration): @param duration: A L{Duration} object or seconds for SparkStreaming. 
""" - if isinstance(duration, (int, long, float)): - duration = Seconds(duration) - self._sc = sparkContext self._jvm = self._sc._jvm self._start_callback_server() @@ -93,7 +89,10 @@ def _start_callback_server(self): gw._python_proxy_port = gw._callback_server.port # update port with real port def _initialize_context(self, sc, duration): - return self._jvm.JavaStreamingContext(sc._jsc, duration._jduration) + return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration)) + + def _jduration(self, seconds): + return self._jvm.Duration(int(seconds * 1000)) @property def sparkContext(self): @@ -111,12 +110,12 @@ def start(self): def awaitTermination(self, timeout=None): """ Wait for the execution to stop. - @param timeout: time to wait in milliseconds + @param timeout: time to wait in seconds """ if timeout is None: self._jssc.awaitTermination() else: - self._jssc.awaitTermination(timeout) + self._jssc.awaitTermination(int(timeout * 1000)) def stop(self, stopSparkContext=True, stopGraceFully=False): """ @@ -139,10 +138,7 @@ def remember(self, duration): @param duration Minimum duration (in seconds) that each DStream should remember its RDDs """ - if isinstance(duration, (int, long, float)): - duration = Seconds(duration) - - self._jssc.remember(duration._jduration) + self._jssc.remember(self._jduration(duration)) def checkpoint(self, directory): """ diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 9dd3556327477..8c79eece773ce 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -22,7 +22,6 @@ from pyspark.storagelevel import StorageLevel from pyspark.streaming.util import rddToFileName, RDDFunction, RDDFunction2 from pyspark.rdd import portable_hash -from pyspark.streaming.duration import Duration, Seconds from pyspark.resultiterable import ResultIterable __all__ = ["DStream"] @@ -334,10 +333,10 @@ def slice(self, begin, end): return [RDD(jrdd, self.ctx, self._jrdd_deserializer) for jrdd in jrdds] def window(self, windowDuration, slideDuration=None): - d = Seconds(windowDuration) + d = self._ssc._jduration(windowDuration) if slideDuration is None: return DStream(self._jdstream.window(d), self._ssc, self._jrdd_deserializer) - s = Seconds(slideDuration) + s = self._ssc._jduration(slideDuration) return DStream(self._jdstream.window(d, s), self._ssc, self._jrdd_deserializer) def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuration): @@ -375,16 +374,12 @@ def invReduceFunc(a, b, t): joined = a.leftOuterJoin(b, numPartitions) return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1) - if not isinstance(windowDuration, Duration): - windowDuration = Seconds(windowDuration) - if not isinstance(slideDuration, Duration): - slideDuration = Seconds(slideDuration) jreduceFunc = RDDFunction2(self.ctx, reduceFunc, reduced._jrdd_deserializer) jinvReduceFunc = RDDFunction2(self.ctx, invReduceFunc, reduced._jrdd_deserializer) dstream = self.ctx._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(), jreduceFunc, jinvReduceFunc, - windowDuration._jduration, - slideDuration._jduration) + self._ssc._jduration(windowDuration), + self._ssc._jduration(slideDuration)) return DStream(dstream.asJavaDStream(), self._ssc, self.ctx.serializer) def updateStateByKey(self, updateFunc, numPartitions=None): diff --git a/python/pyspark/streaming/duration.py b/python/pyspark/streaming/duration.py deleted file mode 100644 index 8660f332a48da..0000000000000 --- 
a/python/pyspark/streaming/duration.py +++ /dev/null @@ -1,401 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -def msDurationToString(ms): - """ - Returns a human-readable string representing a duration such as "35ms" - - >> msDurationToString(10) - '10 ms' - >>> msDurationToString(1000) - '1.0 s' - >>> msDurationToString(60000) - '1.0 m' - >>> msDurationToString(3600000) - '1.00 h' - """ - second = 1000 - minute = 60 * second - hour = 60 * minute - - if ms < second: - return "%d ms" % ms - elif ms < minute: - return "%.1f s" % (float(ms) / second) - elif ms < hour: - return "%.1f m" % (float(ms) / minute) - else: - return "%.2f h" % (float(ms) / hour) - - -class Duration(object): - """ - Duration for Spark Streaming application. Used to set duration - - Most of the time, you would create a Duration object with - C{Duration()}, which will load values from C{spark.streaming.*} Java system - properties as well. In this case, any parameters you set directly on - the C{Duration} object take priority over system properties. - - """ - def __init__(self, millis, _jvm=None): - """ - Create new Duration. 
- - @param millis: milisecond - - """ - self._millis = millis - - from pyspark.context import SparkContext - SparkContext._ensure_initialized() - _jvm = _jvm or SparkContext._jvm - self._jduration = _jvm.Duration(millis) - - def toString(self): - """ - Return duration as string - - >>> d_10 = Duration(10) - >>> d_10.toString() - '10 ms' - """ - return str(self._millis) + " ms" - - def isZero(self): - """ - Check if millis is zero - - >>> d_10 = Duration(10) - >>> d_10.isZero() - False - >>> d_0 = Duration(0) - >>> d_0.isZero() - True - """ - return self._millis == 0 - - def prettyPrint(self): - """ - Return a human-readable string representing a duration - - >>> d_10 = Duration(10) - >>> d_10.prettyPrint() - '10 ms' - >>> d_1sec = Duration(1000) - >>> d_1sec.prettyPrint() - '1.0 s' - >>> d_1min = Duration(60 * 1000) - >>> d_1min.prettyPrint() - '1.0 m' - >>> d_1hour = Duration(60 * 60 * 1000) - >>> d_1hour.prettyPrint() - '1.00 h' - """ - return msDurationToString(self._millis) - - def milliseconds(self): - """ - Return millisecond - - >>> d_10 = Duration(10) - >>> d_10.milliseconds() - 10 - - """ - return self._millis - - def toFormattedString(self): - """ - Return millisecond - - >>> d_10 = Duration(10) - >>> d_10.toFormattedString() - '10' - - """ - return str(self._millis) - - def max(self, other): - """ - Return higher Duration - - >>> d_10 = Duration(10) - >>> d_100 = Duration(100) - >>> d_max = d_10.max(d_100) - >>> print d_max - 100 ms - - """ - Duration._is_duration(other) - if self > other: - return self - else: - return other - - def min(self, other): - """ - Return lower Durattion - - >>> d_10 = Duration(10) - >>> d_100 = Duration(100) - >>> d_min = d_10.min(d_100) - >>> print d_min - 10 ms - - """ - Duration._is_duration(other) - if self < other: - return self - else: - return other - - def __str__(self): - """ - >>> d_10 = Duration(10) - >>> str(d_10) - '10 ms' - - """ - return self.toString() - - def __add__(self, other): - """ - Add Duration and Duration - - >>> d_10 = Duration(10) - >>> d_100 = Duration(100) - >>> d_110 = d_10 + d_100 - >>> print d_110 - 110 ms - """ - Duration._is_duration(other) - return Duration(self._millis + other._millis) - - def __sub__(self, other): - """ - Subtract Duration by Duration - - >>> d_10 = Duration(10) - >>> d_100 = Duration(100) - >>> d_90 = d_100 - d_10 - >>> print d_90 - 90 ms - - """ - Duration._is_duration(other) - return Duration(self._millis - other._millis) - - def __mul__(self, other): - """ - Multiple Duration by Duration - - >>> d_10 = Duration(10) - >>> d_100 = Duration(100) - >>> d_1000 = d_10 * d_100 - >>> print d_1000 - 1000 ms - - """ - Duration._is_duration(other) - return Duration(self._millis * other._millis) - - def __div__(self, other): - """ - Divide Duration by Duration - for Python 2.X - - >>> d_10 = Duration(10) - >>> d_20 = Duration(20) - >>> d_2 = d_20 / d_10 - >>> print d_2 - 2 ms - - """ - Duration._is_duration(other) - return Duration(self._millis / other._millis) - - def __truediv__(self, other): - """ - Divide Duration by Duration - for Python 3.0 - - >>> d_10 = Duration(10) - >>> d_20 = Duration(20) - >>> d_2 = d_20 / d_10 - >>> print d_2 - 2 ms - - """ - Duration._is_duration(other) - return Duration(self._millis / other._millis) - - def __floordiv__(self, other): - """ - Divide Duration by Duration - - >>> d_10 = Duration(10) - >>> d_3 = Duration(3) - >>> d_3 = d_10 // d_3 - >>> print d_3 - 3 ms - - """ - Duration._is_duration(other) - return Duration(self._millis // other._millis) - - def __lt__(self, 
other): - """ - Duration < Duration - - >>> d_10 = Duration(10) - >>> d_20 = Duration(20) - >>> d_10 < d_20 - True - >>> d_20 < d_10 - False - - """ - Duration._is_duration(other) - return self._millis < other._millis - - def __le__(self, other): - """ - Duration <= Duration - - >>> d_10 = Duration(10) - >>> d_20 = Duration(20) - >>> d_10 <= d_20 - True - >>> d_20 <= d_10 - False - - """ - Duration._is_duration(other) - return self._millis <= other._millis - - def __eq__(self, other): - """ - Duration == Duration - - >>> d_10 = Duration(10) - >>> d_20 = Duration(20) - >>> d_10 == d_20 - False - >>> other_d_10 = Duration(10) - >>> d_10 == other_d_10 - True - - """ - Duration._is_duration(other) - return self._millis == other._millis - - def __ne__(self, other): - """ - Duration != Duration - - >>> d_10 = Duration(10) - >>> d_20 = Duration(20) - >>> d_10 != d_20 - True - >>> other_d_10 = Duration(10) - >>> d_10 != other_d_10 - False - - """ - Duration._is_duration(other) - return self._millis != other._millis - - def __gt__(self, other): - """ - Duration > Duration - - >>> d_10 = Duration(10) - >>> d_20 = Duration(20) - >>> d_10 > d_20 - False - >>> d_20 > d_10 - True - - """ - Duration._is_duration(other) - return self._millis > other._millis - - def __ge__(self, other): - """ - Duration >= Duration - - >>> d_10 = Duration(10) - >>> d_20 = Duration(20) - >>> d_10 < d_20 - True - >>> d_20 < d_10 - False - - - """ - Duration._is_duration(other) - return self._millis >= other._millis - - @classmethod - def _is_duration(self, instance): - """ is instance Duration """ - if not isinstance(instance, Duration): - raise TypeError("This should be Duration") - - -def Milliseconds(milliseconds): - """ - Helper function that creates instance of [[pysparkstreaming.duration]] representing - a given number of milliseconds. - - >>> milliseconds = Milliseconds(1) - >>> d_1 = Duration(1) - >>> milliseconds == d_1 - True - - """ - return Duration(milliseconds) - - -def Seconds(seconds): - """ - Helper function that creates instance of [[pysparkstreaming.duration]] representing - a given number of seconds. - - >>> seconds = Seconds(1) - >>> d_1sec = Duration(1000) - >>> seconds == d_1sec - True - - """ - return Duration(seconds * 1000) - - -def Minutes(minutes): - """ - Helper function that creates instance of [[pysparkstreaming.duration]] representing - a given number of minutes. - - >>> minutes = Minutes(1) - >>> d_1min = Duration(60 * 1000) - >>> minutes == d_1min - True - - """ - return Duration(minutes * 60 * 1000) diff --git a/python/pyspark/streaming/jtime.py b/python/pyspark/streaming/jtime.py deleted file mode 100644 index e157640afa4df..0000000000000 --- a/python/pyspark/streaming/jtime.py +++ /dev/null @@ -1,135 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -from pyspark.streaming.duration import Duration - -""" -The name of this file, time is not a good naming for python -because if we do import time when we want to use native python time package, it does -not import python time package. -""" -# TODO: add doctest - - -class Time(object): - """ - Time for Spark Streaming application. Used to set Time - - Most of the time, you would create a Duration object with - C{Time()}, which will load values from C{spark.streaming.*} Java system - properties as well. In this case, any parameters you set directly on - the C{Time} object take priority over system properties. - - """ - def __init__(self, millis, _jvm=None): - """ - Create new Time. - - @param millis: milisecond - - @param _jvm: internal parameter used to pass a handle to the - Java VM; does not need to be set by users - - """ - self._millis = millis - - from pyspark.context import StreamingContext - StreamingContext._ensure_initialized() - _jvm = _jvm or StreamingContext._jvm - self._jtime = _jvm.Time(millis) - - def toString(self): - """ Return time as string """ - return str(self._millis) + " ms" - - def milliseconds(self): - """ Return millisecond """ - return self._millis - - def max(self, other): - """ Return higher Time """ - Time._is_time(other) - if self > other: - return self - else: - return other - - def min(self, other): - """ Return lower Time """ - Time._is_time(other) - if self < other: - return self - else: - return other - - def __add__(self, other): - """ Add Time and Time """ - Duration._is_duration(other) - return Time(self._millis + other._millis) - - def __sub__(self, other): - """ Subtract Time by Duration or Time """ - if isinstance(other, Duration): - return Time(self._millis - other._millis) - elif isinstance(other, Time): - return Duration(self._millis, other._millis) - else: - raise TypeError - - def __lt__(self, other): - """ Time < Time """ - Time._is_time(other) - return self._millis < other._millis - - def __le__(self, other): - """ Time <= Time """ - Time._is_time(other) - return self._millis <= other._millis - - def __eq__(self, other): - """ Time == Time """ - Time._is_time(other) - return self._millis == other._millis - - def __ne__(self, other): - """ Time != Time """ - Time._is_time(other) - return self._millis != other._millis - - def __gt__(self, other): - """ Time > Time """ - Time._is_time(other) - return self._millis > other._millis - - def __ge__(self, other): - """ Time >= Time """ - Time._is_time(other) - return self._millis >= other._millis - - def isMultipbleOf(self, duration): - """ is multiple by Duration """ - Duration._is_duration(duration) - return self._millis % duration._millis == 0 - - @classmethod - def _is_time(self, instance): - """ is instance Time """ - if not isinstance(instance, Time): - raise TypeError - -# TODO: implement until -# TODO: implement to diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index a585bbfa06f5b..1684da580f973 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -41,7 +41,7 @@ def setUp(self): self.sc = SparkContext(appName=class_name) self.sc.setCheckpointDir("/tmp") # TODO: decrease duration to speed up tests - self.ssc = StreamingContext(self.sc, duration=Seconds(1)) + self.ssc = StreamingContext(self.sc, duration=1) def tearDown(self): self.ssc.stop() @@ -315,7 +315,7 @@ def func(dstream): class TestStreamingContext(unittest.TestCase): def setUp(self): self.sc = SparkContext(master="local[2]", appName=self.__class__.__name__) 
- self.batachDuration = Seconds(0.1) + self.batachDuration = 0.1 self.ssc = StreamingContext(self.sc, self.batachDuration) def tearDown(self): From 604323fd39a3b3f0f39540bf71fc737630f2b110 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 27 Sep 2014 00:30:06 -0700 Subject: [PATCH 307/347] enable streaming tests --- python/run-tests | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/python/run-tests b/python/run-tests index 79d7602ccbc87..5aa9212c8adc1 100755 --- a/python/run-tests +++ b/python/run-tests @@ -70,8 +70,6 @@ export PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py" run_test "pyspark/accumulators.py" run_test "pyspark/serializers.py" -run_test "pyspark/streaming/duration.py" -run_test "pyspark/streaming/util.py" unset PYSPARK_DOC_TEST run_test "pyspark/shuffle.py" run_test "pyspark/tests.py" @@ -85,9 +83,7 @@ run_test "pyspark/mllib/stat.py" run_test "pyspark/mllib/tests.py" run_test "pyspark/mllib/tree.py" run_test "pyspark/mllib/util.py" -if [ -n "$_RUN_STREAMING_TESTS" ]; then - run_test "pyspark/streaming/tests.py" -fi +run_test "pyspark/streaming/tests.py" # Try to test with PyPy if [ $(which pypy) ]; then @@ -108,6 +104,7 @@ if [ $(which pypy) ]; then unset PYSPARK_DOC_TEST run_test "pyspark/shuffle.py" run_test "pyspark/tests.py" + run_test "pyspark/streaming/tests.py" fi if [[ $FAILED == 0 ]]; then From b32774cc3cc7493b360bd9e5b8b01df28968d0c2 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 27 Sep 2014 00:36:43 -0700 Subject: [PATCH 308/347] move java_import into streaming --- python/pyspark/java_gateway.py | 4 +--- python/pyspark/streaming/context.py | 7 +++++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index c3fef42d118bd..db5b97f8472d1 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -23,6 +23,7 @@ import platform from subprocess import Popen, PIPE from threading import Thread + from py4j.java_gateway import java_import, JavaGateway, GatewayClient @@ -108,9 +109,6 @@ def run(self): java_import(gateway.jvm, "org.apache.spark.SparkConf") java_import(gateway.jvm, "org.apache.spark.api.java.*") java_import(gateway.jvm, "org.apache.spark.api.python.*") - java_import(gateway.jvm, "org.apache.spark.streaming.*") - java_import(gateway.jvm, "org.apache.spark.streaming.api.java.*") - java_import(gateway.jvm, "org.apache.spark.streaming.api.python.*") java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*") java_import(gateway.jvm, "org.apache.spark.sql.SQLContext") java_import(gateway.jvm, "org.apache.spark.sql.hive.HiveContext") diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 88e0cbbede1be..a647c9ec734df 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -22,6 +22,7 @@ from pyspark.streaming.dstream import DStream from py4j.java_collections import ListConverter +from py4j.java_gateway import java_import __all__ = ["StreamingContext"] @@ -72,7 +73,7 @@ def __init__(self, sparkContext, duration): should be set, either through the named parameters here or through C{conf}. @param sparkContext: L{SparkContext} object. - @param duration: A L{Duration} object or seconds for SparkStreaming. + @param duration: seconds for SparkStreaming. 
""" self._sc = sparkContext @@ -89,6 +90,9 @@ def _start_callback_server(self): gw._python_proxy_port = gw._callback_server.port # update port with real port def _initialize_context(self, sc, duration): + java_import(self._jvm, "org.apache.spark.streaming.*") + java_import(self._jvm, "org.apache.spark.streaming.api.java.*") + java_import(self._jvm, "org.apache.spark.streaming.api.python.*") return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration)) def _jduration(self, seconds): @@ -217,7 +221,6 @@ def union(self, *dstreams): raise ValueError("should have at least one DStream to union") if len(dstreams) == 1: return dstreams[0] - self._check_serialzers(dstreams) first = dstreams[0] jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]], SparkContext._gateway._gateway_client) From 74df565e26e9bf7b107cc678e1668dfda7d534ef Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 27 Sep 2014 00:48:03 -0700 Subject: [PATCH 309/347] fix print and docs --- python/pyspark/streaming/dstream.py | 56 ++++++++++++----------------- 1 file changed, 23 insertions(+), 33 deletions(-) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 8c79eece773ce..01ca56a7a0387 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -17,6 +17,7 @@ from itertools import chain, ifilter, imap import operator +from datetime import datetime from pyspark import RDD from pyspark.storagelevel import StorageLevel @@ -54,17 +55,6 @@ def sum(self): """ return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add) - def print_(self, label=None): - """ - Since print is reserved name for python, we cannot define a "print" method function. - This function prints serialized data in RDD in DStream because Scala and Java cannot - deserialized pickled python object. Please use DStream.pyprint() to print results. - - Call DStream.print() and this function will print byte array in the DStream - """ - # a hack to call print function in DStream - getattr(self._jdstream, "print")(label) - def filter(self, f): """ Return a new DStream containing only the elements that satisfy predicate. @@ -154,19 +144,15 @@ def foreachRDD(self, func): jfunc = RDDFunction(self.ctx, func, self._jrdd_deserializer) self.ctx._jvm.PythonForeachDStream(self._jdstream.dstream(), jfunc) - def pyprint(self): + def pprint(self): """ Print the first ten elements of each RDD generated in this DStream. This is an output operator, so this DStream will be registered as an output stream and there materialized. """ def takeAndPrint(rdd, time): - """ - Closure to take element from RDD and print first 10 elements. - This closure is called by py4j callback server. - """ taken = rdd.take(11) print "-------------------------------------------" - print "Time: %s" % (str(time)) + print "Time: %s" % datetime.fromtimestamp(time / 1000.0) print "-------------------------------------------" for record in taken[:10]: print record @@ -176,6 +162,20 @@ def takeAndPrint(rdd, time): self.foreachRDD(takeAndPrint) + def collect(self): + """ + Collect each RDDs into the returned list. + + :return: list, which will have the collected items. 
+ """ + result = [] + + def get_output(rdd, time): + r = rdd.collect() + result.append(r) + self.foreachRDD(get_output) + return result + def mapValues(self, f): """ Pass each value in the key-value pair RDD through a map function @@ -196,9 +196,9 @@ def flatMapValues(self, f): def glom(self): """ - Return a new DStream in which RDD is generated by applying glom() to RDD of - this DStream. Applying glom() to an RDD coalesces all elements within each partition into - an list. + Return a new DStream in which RDD is generated by applying glom() + to RDD of this DStream. Applying glom() to an RDD coalesces all + elements within each partition into an list. """ def func(iterator): yield list(iterator) @@ -228,11 +228,11 @@ def checkpoint(self, interval): Mark this DStream for checkpointing. It will be saved to a file inside the checkpoint directory set with L{SparkContext.setCheckpointDir()} - @param interval: Time interval after which generated RDD will be checkpointed - interval has to be pyspark.streaming.duration.Duration + @param interval: time in seconds, after which generated RDD will + be checkpointed """ self.is_checkpointed = True - self._jdstream.checkpoint(interval._jduration) + self._jdstream.checkpoint(self._ssc._jduration(interval)) return self def groupByKey(self, numPartitions=None): @@ -245,7 +245,6 @@ def groupByKey(self, numPartitions=None): Note: If you are grouping in order to perform an aggregation (such as a sum or average) over each key, using reduceByKey will provide much better performance. - """ return self.transform(lambda rdd: rdd.groupByKey(numPartitions)) @@ -288,15 +287,6 @@ def saveAsPickleFile(rdd, time): return self.foreachRDD(saveAsPickleFile) - def collect(self): - result = [] - - def get_output(rdd, time): - r = rdd.collect() - result.append(r) - self.foreachRDD(get_output) - return result - def transform(self, func): return TransformedDStream(self, lambda a, t: func(a), True) From 26ea39619c59f28b1ad18b8e44abef25d8d1dbae Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 27 Sep 2014 21:17:23 -0700 Subject: [PATCH 310/347] refactor --- python/pyspark/streaming/dstream.py | 25 ++++---- python/pyspark/streaming/tests.py | 9 +-- python/pyspark/streaming/util.py | 47 +++----------- .../streaming/api/python/PythonDStream.scala | 63 +++++-------------- 4 files changed, 40 insertions(+), 104 deletions(-) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 01ca56a7a0387..d41eca020feb1 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -21,7 +21,7 @@ from pyspark import RDD from pyspark.storagelevel import StorageLevel -from pyspark.streaming.util import rddToFileName, RDDFunction, RDDFunction2 +from pyspark.streaming.util import rddToFileName, RDDFunction from pyspark.rdd import portable_hash from pyspark.resultiterable import ResultIterable @@ -141,7 +141,7 @@ def foreachRDD(self, func): This is an output operator, so this DStream will be registered as an output stream and there materialized. 
""" - jfunc = RDDFunction(self.ctx, func, self._jrdd_deserializer) + jfunc = RDDFunction(self.ctx, lambda a, _, t: func(a, t), self._jrdd_deserializer) self.ctx._jvm.PythonForeachDStream(self._jdstream.dstream(), jfunc) def pprint(self): @@ -294,7 +294,7 @@ def transformWithTime(self, func): return TransformedDStream(self, func, False) def transformWith(self, func, other, keepSerializer=False): - jfunc = RDDFunction2(self.ctx, func, self._jrdd_deserializer) + jfunc = RDDFunction(self.ctx, lambda a, b, t: func(a, b), self._jrdd_deserializer) dstream = self.ctx._jvm.PythonTransformed2DStream(self._jdstream.dstream(), other._jdstream.dstream(), jfunc) jrdd_serializer = self._jrdd_deserializer if keepSerializer else self.ctx.serializer @@ -304,16 +304,16 @@ def repartitions(self, numPartitions): return self.transform(lambda rdd: rdd.repartition(numPartitions)) def union(self, other): - return self.transformWith(lambda a, b, t: a.union(b), other, True) + return self.transformWith(lambda a, b: a.union(b), other, True) def cogroup(self, other): - return self.transformWith(lambda a, b, t: a.cogroup(b), other) + return self.transformWith(lambda a, b: a.cogroup(b), other) def leftOuterJoin(self, other): - return self.transformWith(lambda a, b, t: a.leftOuterJion(b), other) + return self.transformWith(lambda a, b: a.leftOuterJion(b), other) def rightOuterJoin(self, other): - return self.transformWith(lambda a, b, t: a.rightOuterJoin(b), other) + return self.transformWith(lambda a, b: a.rightOuterJoin(b), other) def _jtime(self, milliseconds): return self.ctx._jvm.Time(milliseconds) @@ -364,8 +364,8 @@ def invReduceFunc(a, b, t): joined = a.leftOuterJoin(b, numPartitions) return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1) - jreduceFunc = RDDFunction2(self.ctx, reduceFunc, reduced._jrdd_deserializer) - jinvReduceFunc = RDDFunction2(self.ctx, invReduceFunc, reduced._jrdd_deserializer) + jreduceFunc = RDDFunction(self.ctx, reduceFunc, reduced._jrdd_deserializer) + jinvReduceFunc = RDDFunction(self.ctx, invReduceFunc, reduced._jrdd_deserializer) dstream = self.ctx._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(), jreduceFunc, jinvReduceFunc, self._ssc._jduration(windowDuration), @@ -384,8 +384,8 @@ def reduceFunc(a, b, t): (k, list(vb), list(va)[0] if len(va) else None)) return g.mapPartitions(lambda x: updateFunc(x) or []) - jreduceFunc = RDDFunction2(self.ctx, reduceFunc, - self.ctx.serializer, self._jrdd_deserializer) + jreduceFunc = RDDFunction(self.ctx, reduceFunc, + self.ctx.serializer, self._jrdd_deserializer) dstream = self.ctx._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc) return DStream(dstream.asJavaDStream(), self._ssc, self.ctx.serializer) @@ -417,7 +417,8 @@ def _jdstream(self): if self._jdstream_val is not None: return self._jdstream_val - jfunc = RDDFunction(self.ctx, self.func, self.prev._jrdd_deserializer) + func = self.func + jfunc = RDDFunction(self.ctx, lambda a, _, t: func(a, t), self.prev._jrdd_deserializer) jdstream = self.ctx._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc, self.reuse).asJavaDStream() self._jdstream_val = jdstream diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 1684da580f973..06fcc29850504 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -346,17 +346,18 @@ def test_queueStream(self): result = dstream.collect() self.ssc.start() time.sleep(1) - self.assertEqual(input, result) + self.assertEqual(input, 
result[:3]) def test_union(self): input = [range(i) for i in range(3)] dstream = self.ssc.queueStream(input) - dstream2 = self.ssc.union(dstream, dstream) - result = dstream.collect() + dstream2 = self.ssc.queueStream(input) + dstream3 = self.ssc.union(dstream, dstream2) + result = dstream3.collect() self.ssc.start() time.sleep(1) expected = [i * 2 for i in input] - self.assertEqual(input, result) + self.assertEqual(expected, result[:3]) if __name__ == "__main__": diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index feff1b3889c49..02b51dc472c51 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -20,44 +20,13 @@ class RDDFunction(object): """ - This class is for py4j callback. This class is related with - org.apache.spark.streaming.api.python.PythonRDDFunction. + This class is for py4j callback. """ - def __init__(self, ctx, func, jrdd_deserializer): + def __init__(self, ctx, func, deserializer, deserializer2=None): self.ctx = ctx self.func = func - self.deserializer = jrdd_deserializer - - def call(self, jrdd, milliseconds): - try: - emptyRDD = getattr(self.ctx, "_emptyRDD", None) - if emptyRDD is None: - self.ctx._emptyRDD = emptyRDD = self.ctx.parallelize([]).cache() - rdd = RDD(jrdd, self.ctx, self.deserializer) if jrdd else emptyRDD - r = self.func(rdd, milliseconds) - if r: - return r._jrdd - except: - import traceback - traceback.print_exc() - - def __repr__(self): - return "RDDFunction(%s, %s)" % (str(self.deserializer), str(self.func)) - - class Java: - implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction'] - - -class RDDFunction2(object): - """ - This class is for py4j callback. This class is related with - org.apache.spark.streaming.api.python.PythonRDDFunction2. 
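The refactor here collapses RDDFunction and RDDFunction2 into a single callback class with one three-argument call(rdd, rdd2, milliseconds) entry point; unary operations simply ignore the second RDD, mirroring the lambda wrappers used in dstream.py. A rough sketch of that adapter idea (the helper names are illustrative, not part of the patch):

def as_unary(func):
    # adapt a (rdd, time) function to the unified (rdd, other, time) shape;
    # the unused second RDD is ignored
    return lambda rdd, _other, t: func(rdd, t)

def as_binary(func):
    # adapt an (rdd, other) function, dropping the batch time
    return lambda rdd, other, _t: func(rdd, other)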
- """ - def __init__(self, ctx, func, jrdd_deserializer, jrdd_deserializer2=None): - self.ctx = ctx - self.func = func - self.jrdd_deserializer = jrdd_deserializer - self.jrdd_deserializer2 = jrdd_deserializer2 or jrdd_deserializer + self.deserializer = deserializer + self.deserializer2 = deserializer2 or deserializer def call(self, jrdd, jrdd2, milliseconds): try: @@ -65,12 +34,12 @@ def call(self, jrdd, jrdd2, milliseconds): if emptyRDD is None: self.ctx._emptyRDD = emptyRDD = self.ctx.parallelize([]).cache() - rdd = RDD(jrdd, self.ctx, self.jrdd_deserializer) if jrdd else emptyRDD - other = RDD(jrdd2, self.ctx, self.jrdd_deserializer2) if jrdd2 else emptyRDD + rdd = RDD(jrdd, self.ctx, self.deserializer) if jrdd else emptyRDD + other = RDD(jrdd2, self.ctx, self.deserializer2) if jrdd2 else emptyRDD r = self.func(rdd, other, milliseconds) if r: return r._jrdd - except: + except Exception: import traceback traceback.print_exc() @@ -78,7 +47,7 @@ def __repr__(self): return "RDDFunction2(%s)" % (str(self.func)) class Java: - implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction2'] + implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction'] def rddToFileName(prefix, suffix, time): diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 828a620e4c08f..c0a1aa71840a5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -19,8 +19,7 @@ package org.apache.spark.streaming.api.python import java.util.{ArrayList => JArrayList} -import org.apache.spark.Partitioner -import org.apache.spark.rdd.{CoGroupedRDD, UnionRDD, PartitionerAwareUnionRDD, RDD} +import org.apache.spark.rdd.RDD import org.apache.spark.api.java._ import org.apache.spark.api.python._ import org.apache.spark.storage.StorageLevel @@ -28,41 +27,14 @@ import org.apache.spark.streaming.{Interval, Duration, Time} import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.api.java._ -import scala.collection.mutable.ArrayBuffer -import scala.reflect.ClassTag - - -/** - * Interface for Python callback function with two arguments - */ -trait PythonRDDFunction { - def call(rdd: JavaRDD[_], time: Long): JavaRDD[Array[Byte]] -} - -class RDDFunction(pfunc: PythonRDDFunction) { - def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { - val jrdd = if (rdd.isDefined) { - JavaRDD.fromRDD(rdd.get) - } else { - null - } - val r = pfunc.call(jrdd, time.milliseconds) - if (r != null) { - Some(r.rdd) - } else { - None - } - } -} - /** * Interface for Python callback function with three arguments */ -trait PythonRDDFunction2 { +trait PythonRDDFunction { def call(rdd: JavaRDD[_], rdd2: JavaRDD[_], time: Long): JavaRDD[Array[Byte]] } -class RDDFunction2(pfunc: PythonRDDFunction2) { +class RDDFunction(pfunc: PythonRDDFunction) { def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { val jrdd = if (rdd.isDefined) { JavaRDD.fromRDD(rdd.get) @@ -114,7 +86,7 @@ private[spark] class PythonTransformedDStream (parent: DStream[_], pfunc: Python if (reuse && lastResult != null) { Some(lastResult.copyTo(rdd1.get)) } else { - val r = func(rdd1, validTime) + val r = func(rdd1, None, validTime) if (reuse && r.isDefined && lastResult == null) { r.get match { case rdd: PythonRDD => @@ -137,10 +109,10 @@ 
private[spark] class PythonTransformedDStream (parent: DStream[_], pfunc: Python */ private[spark] class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_], - pfunc: PythonRDDFunction2) + pfunc: PythonRDDFunction) extends DStream[Array[Byte]] (parent.ssc) { - val func = new RDDFunction2(pfunc) + val func = new RDDFunction(pfunc) override def slideDuration: Duration = parent.slideDuration @@ -157,10 +129,10 @@ class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_], * similar to StateDStream */ private[spark] -class PythonStateDStream(parent: DStream[Array[Byte]], preduceFunc: PythonRDDFunction2) +class PythonStateDStream(parent: DStream[Array[Byte]], preduceFunc: PythonRDDFunction) extends PythonDStream(parent) { - val reduceFunc = new RDDFunction2(preduceFunc) + val reduceFunc = new RDDFunction(preduceFunc) super.persist(StorageLevel.MEMORY_ONLY) override val mustCheckpoint = true @@ -177,12 +149,12 @@ class PythonStateDStream(parent: DStream[Array[Byte]], preduceFunc: PythonRDDFun } /** - * Copied from ReducedWindowedDStream + * similar to ReducedWindowedDStream */ private[spark] class PythonReducedWindowedDStream(parent: DStream[Array[Byte]], - preduceFunc: PythonRDDFunction2, - pinvReduceFunc: PythonRDDFunction2, + preduceFunc: PythonRDDFunction, + pinvReduceFunc: PythonRDDFunction, _windowDuration: Duration, _slideDuration: Duration ) extends PythonStateDStream(parent, preduceFunc) { @@ -197,7 +169,7 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]], "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")" ) - val invReduceFunc = new RDDFunction2(pinvReduceFunc) + val invReduceFunc = new RDDFunction(pinvReduceFunc) def windowDuration: Duration = _windowDuration override def slideDuration: Duration = _slideDuration @@ -209,12 +181,6 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]], currentTime) val previousWindow = currentWindow - slideDuration - logDebug("Window time = " + windowDuration) - logDebug("Slide time = " + slideDuration) - logDebug("ZeroTime = " + zeroTime) - logDebug("Current window = " + currentWindow) - logDebug("Previous window = " + previousWindow) - // _____________________________ // | previous window _________|___________________ // |___________________| current window | --------------> Time @@ -271,7 +237,7 @@ class PythonForeachDStream( prev, (rdd: RDD[Array[Byte]], time: Time) => { if (rdd != null) { - foreachFunction.call(rdd, time.milliseconds) + foreachFunction.call(rdd, null, time.milliseconds) } } ) { @@ -283,7 +249,6 @@ class PythonForeachDStream( /** * similar to QueueInputStream */ - class PythonDataInputStream( ssc_ : JavaStreamingContext, inputRDDs: JArrayList[JavaRDD[Array[Byte]]], @@ -294,7 +259,7 @@ class PythonDataInputStream( val emptyRDD = if (defaultRDD != null) { Some(defaultRDD.rdd) } else { - None // ssc.sparkContext.emptyRDD[Array[Byte]] + Some(ssc.sparkContext.emptyRDD[Array[Byte]]) } def start() {} From 7001b5136fdd462af33b62a132e87bf302911082 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 27 Sep 2014 21:50:58 -0700 Subject: [PATCH 311/347] refactor of queueStream() --- python/pyspark/streaming/context.py | 11 ++-- .../streaming/api/python/PythonDStream.scala | 55 ++++--------------- 2 files changed, 19 insertions(+), 47 deletions(-) diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index a647c9ec734df..00a1ec6f31fec 100644 --- a/python/pyspark/streaming/context.py +++ 
b/python/pyspark/streaming/context.py @@ -184,7 +184,7 @@ def _check_serialzers(self, rdds): # reset them to sc.serializer rdds[i] = rdds[i].map(lambda x: x, preservesPartitioning=True) - def queueStream(self, queue, oneAtATime=False, default=None): + def queueStream(self, queue, oneAtATime=True, default=None): """ Create an input stream from an queue of RDDs or list. In each batch, it will process either one or all of the RDDs returned by the queue. @@ -200,9 +200,12 @@ def queueStream(self, queue, oneAtATime=False, default=None): self._check_serialzers(rdds) jrdds = ListConverter().convert([r._jrdd for r in rdds], SparkContext._gateway._gateway_client) - jdstream = self._jvm.PythonDataInputStream(self._jssc, jrdds, oneAtATime, - default and default._jrdd) - return DStream(jdstream.asJavaDStream(), self, rdds[0]._jrdd_deserializer) + queue = self._jvm.PythonDStream.toRDDQueue(jrdds) + if default: + jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd) + else: + jdstream = self._jssc.queueStream(queue, oneAtATime) + return DStream(jdstream, self, rdds[0]._jrdd_deserializer) def transform(self, dstreams, transformFunc): """ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index c0a1aa71840a5..d7dd0a0c5c88b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming.api.python import java.util.{ArrayList => JArrayList} +import scala.collection.JavaConversions._ import org.apache.spark.rdd.RDD import org.apache.spark.api.java._ @@ -65,6 +66,16 @@ abstract class PythonDStream(parent: DStream[_]) extends DStream[Array[Byte]] (p val asJavaDStream = JavaDStream.fromDStream(this) } +object PythonDStream { + + // convert list of RDD into queue of RDDs, for ssc.queueStream() + def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = { + val queue = new java.util.LinkedList[JavaRDD[Array[Byte]]] + rdds.forall(queue.add(_)) + queue + } +} + /** * Transformed DStream in Python. 
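With the refactor above, queueStream() builds a java.util.Queue via PythonDStream.toRDDQueue and hands it to the stock JavaStreamingContext.queueStream instead of a custom input stream. A usage sketch (app name, data and sleep time are illustrative):

import time
from pyspark import SparkContext
from pyspark.streaming.context import StreamingContext

sc = SparkContext("local[2]", "QueueSketch")
ssc = StreamingContext(sc, 0.5)
# with oneAtATime=True (now the default) each batch consumes one queued
# dataset; with oneAtATime=False a batch processes all of the queued RDDs;
# `default` supplies the RDD to fall back on once the queue runs dry
stream = ssc.queueStream([[1, 2, 3], [4, 5], [6]], oneAtATime=True)
result = stream.collect()
ssc.start()
time.sleep(2)
ssc.stop()
print result[:3]        # [[1, 2, 3], [4, 5], [6]]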
* @@ -243,46 +254,4 @@ class PythonForeachDStream( ) { this.register() -} - - -/** - * similar to QueueInputStream - */ -class PythonDataInputStream( - ssc_ : JavaStreamingContext, - inputRDDs: JArrayList[JavaRDD[Array[Byte]]], - oneAtAtime: Boolean, - defaultRDD: JavaRDD[Array[Byte]] - ) extends InputDStream[Array[Byte]](JavaStreamingContext.toStreamingContext(ssc_)) { - - val emptyRDD = if (defaultRDD != null) { - Some(defaultRDD.rdd) - } else { - Some(ssc.sparkContext.emptyRDD[Array[Byte]]) - } - - def start() {} - - def stop() {} - - def compute(validTime: Time): Option[RDD[Array[Byte]]] = { - val index = ((validTime - zeroTime) / slideDuration - 1).toInt - if (oneAtAtime) { - if (index == 0) { - val rdds = inputRDDs.toArray.map(_.asInstanceOf[JavaRDD[Array[Byte]]].rdd).toSeq - Some(ssc.sparkContext.union(rdds)) - } else { - emptyRDD - } - } else { - if (index < inputRDDs.size()) { - Some(inputRDDs.get(index).rdd) - } else { - emptyRDD - } - } - } - - val asJavaDStream = JavaDStream.fromDStream(this) -} +} \ No newline at end of file From fce0ef5ffdf7d43052978a35b238bbc4ee434cc0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 27 Sep 2014 22:41:04 -0700 Subject: [PATCH 312/347] rafactor of foreachRDD() --- python/pyspark/streaming/dstream.py | 3 +- .../streaming/api/python/PythonDStream.scala | 55 ++++++++----------- 2 files changed, 26 insertions(+), 32 deletions(-) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index d41eca020feb1..8a9e2dab7fb07 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -142,7 +142,8 @@ def foreachRDD(self, func): stream and there materialized. """ jfunc = RDDFunction(self.ctx, lambda a, _, t: func(a, t), self._jrdd_deserializer) - self.ctx._jvm.PythonForeachDStream(self._jdstream.dstream(), jfunc) + api = self._ssc._jvm.PythonDStream + api.callForeachRDD(self._jdstream, jfunc) def pprint(self): """ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index d7dd0a0c5c88b..66cf0c968478c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -20,9 +20,10 @@ package org.apache.spark.streaming.api.python import java.util.{ArrayList => JArrayList} import scala.collection.JavaConversions._ -import org.apache.spark.rdd.RDD import org.apache.spark.api.java._ +import org.apache.spark.api.java.function.{Function2 => JFunction2} import org.apache.spark.api.python._ +import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Interval, Duration, Time} import org.apache.spark.streaming.dstream._ @@ -35,19 +36,22 @@ trait PythonRDDFunction { def call(rdd: JavaRDD[_], rdd2: JavaRDD[_], time: Long): JavaRDD[Array[Byte]] } -class RDDFunction(pfunc: PythonRDDFunction) { - def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { - val jrdd = if (rdd.isDefined) { +class RDDFunction(pfunc: PythonRDDFunction) extends Serializable { + + def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { + apply(rdd, None, time) + } + + def wrapRDD(rdd: Option[RDD[_]]): JavaRDD[_] = { + if (rdd.isDefined) { JavaRDD.fromRDD(rdd.get) } else { null } - val jrdd2 = if (rdd2.isDefined) { - JavaRDD.fromRDD(rdd2.get) - } else { - null - } - val r = 
pfunc.call(jrdd, jrdd2, time.milliseconds) + } + + def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { + val r = pfunc.call(wrapRDD(rdd), wrapRDD(rdd2), time.milliseconds) if (r != null) { Some(r.rdd) } else { @@ -66,7 +70,13 @@ abstract class PythonDStream(parent: DStream[_]) extends DStream[Array[Byte]] (p val asJavaDStream = JavaDStream.fromDStream(this) } -object PythonDStream { +private[spark] object PythonDStream { + + // helper function for DStream.foreachRDD(), + // cannot be `foreachRDD`, it will confusing py4j + def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pyfunc: PythonRDDFunction): Unit = { + jdstream.dstream.foreachRDD((rdd, time) => pyfunc.call(rdd, null, time.milliseconds)) + } // convert list of RDD into queue of RDDs, for ssc.queueStream() def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = { @@ -97,7 +107,7 @@ private[spark] class PythonTransformedDStream (parent: DStream[_], pfunc: Python if (reuse && lastResult != null) { Some(lastResult.copyTo(rdd1.get)) } else { - val r = func(rdd1, None, validTime) + val r = func(rdd1, validTime) if (reuse && r.isDefined && lastResult == null) { r.get match { case rdd: PythonRDD => @@ -206,8 +216,9 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]], // Get the RDD of the reduced value of the previous window val previousWindowRDD = getOrCompute(previousWindow.endTime) + // for small window, reduce once will be better than twice if (windowDuration > slideDuration * 5 && previousWindowRDD.isDefined) { - // subtle the values from old RDDs + // subtract the values from old RDDs val oldRDDs = parent.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideDuration) val subbed = if (oldRDDs.size > 0) { @@ -236,22 +247,4 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]], } } } -} - -/** - * This is used for foreachRDD() in Python - */ -class PythonForeachDStream( - prev: DStream[Array[Byte]], - foreachFunction: PythonRDDFunction - ) extends ForEachDStream[Array[Byte]]( - prev, - (rdd: RDD[Array[Byte]], time: Time) => { - if (rdd != null) { - foreachFunction.call(rdd, null, time.milliseconds) - } - } - ) { - - this.register() } \ No newline at end of file From e059ca224d99b017355f62c157f7a71d9f3ec260 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 27 Sep 2014 22:58:29 -0700 Subject: [PATCH 313/347] move check of window into Python --- python/pyspark/streaming/dstream.py | 9 +++++++++ python/pyspark/streaming/tests.py | 6 ++++++ .../spark/streaming/api/python/PythonDStream.scala | 13 ++----------- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 8a9e2dab7fb07..ffcf70cc854ab 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -354,6 +354,15 @@ def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None) def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration, numPartitions=None): + + duration = self._jdstream.dstream().slideDuration().milliseconds() + if int(windowDuration * 1000) % duration != 0: + raise ValueError("windowDuration must be multiple of the slide duration (%d ms)" + % duration) + if int(slideDuration * 1000) % duration != 0: + raise ValueError("slideDuration must be multiple of the slide duration (%d ms)" + % duration) + reduced = self.reduceByKey(func) def reduceFunc(a, b, t): diff --git 
a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 06fcc29850504..843d6ee04ca33 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -292,6 +292,12 @@ def func(dstream): [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]] self._test_func(input, func, expected) + def test_reduce_by_invalid_window(self): + input1 = [range(3), range(5), range(1), range(6)] + d1 = self.ssc.queueStream(input1) + self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1)) + self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1)) + def update_state_by_key(self): def updater(it): diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 66cf0c968478c..47c3974b61699 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -92,7 +92,8 @@ private[spark] object PythonDStream { * If the result RDD is PythonRDD, then it will cache it as an template for future use, * this can reduce the Python callbacks. */ -private[spark] class PythonTransformedDStream (parent: DStream[_], pfunc: PythonRDDFunction, +private[spark] +class PythonTransformedDStream (parent: DStream[_], pfunc: PythonRDDFunction, var reuse: Boolean = false) extends PythonDStream(parent) { @@ -180,16 +181,6 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]], _slideDuration: Duration ) extends PythonStateDStream(parent, preduceFunc) { - assert(_windowDuration.isMultipleOf(parent.slideDuration), - "The window duration of ReducedWindowedDStream (" + _windowDuration + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")" - ) - - assert(_slideDuration.isMultipleOf(parent.slideDuration), - "The slide duration of ReducedWindowedDStream (" + _slideDuration + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")" - ) - val invReduceFunc = new RDDFunction(pinvReduceFunc) def windowDuration: Duration = _windowDuration From 847f9b9faba9f9e6af20c9f5e72e68bc9eb52f4d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sun, 28 Sep 2014 00:20:34 -0700 Subject: [PATCH 314/347] add more docs, add first(), take() --- python/pyspark/streaming/context.py | 3 + python/pyspark/streaming/dstream.py | 243 ++++++++++++++++-- python/pyspark/streaming/tests.py | 15 ++ .../streaming/api/python/PythonDStream.scala | 8 +- 4 files changed, 243 insertions(+), 26 deletions(-) diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 00a1ec6f31fec..7879d1b7679d9 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -96,6 +96,9 @@ def _initialize_context(self, sc, duration): return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration)) def _jduration(self, seconds): + """ + Create Duration object given number of seconds + """ return self._jvm.Duration(int(seconds * 1000)) @property diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index ffcf70cc854ab..acd9f27c46cbe 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -17,6 +17,7 @@ from itertools import chain, ifilter, imap import operator +import time from datetime import datetime from pyspark import RDD @@ -163,6 
+164,29 @@ def takeAndPrint(rdd, time): self.foreachRDD(takeAndPrint) + def first(self): + """ + Return the first RDD in the stream. + """ + return self.take(1)[0] + + def take(self, n): + """ + Return the first `n` RDDs in the stream (will start and stop). + """ + rdds = [] + + def take(rdd, _): + if rdd: + rdds.append(rdd) + if len(rdds) == n: + # FIXME: NPE in JVM + self._ssc.stop(False) + self.foreachRDD(take) + self._ssc.start() + self._ssc.awaitTermination() + return rdds + def collect(self): """ Collect each RDDs into the returned list. @@ -289,12 +313,24 @@ def saveAsPickleFile(rdd, time): return self.foreachRDD(saveAsPickleFile) def transform(self, func): + """ + Return a new DStream in which each RDD is generated by applying a function + on each RDD of 'this' DStream. + """ return TransformedDStream(self, lambda a, t: func(a), True) def transformWithTime(self, func): + """ + Return a new DStream in which each RDD is generated by applying a function + on each RDD of 'this' DStream. + """ return TransformedDStream(self, func, False) def transformWith(self, func, other, keepSerializer=False): + """ + Return a new DStream in which each RDD is generated by applying a function + on each RDD of 'this' DStream and 'other' DStream. + """ jfunc = RDDFunction(self.ctx, lambda a, b, t: func(a, b), self._jrdd_deserializer) dstream = self.ctx._jvm.PythonTransformed2DStream(self._jdstream.dstream(), other._jdstream.dstream(), jfunc) @@ -302,28 +338,114 @@ def transformWith(self, func, other, keepSerializer=False): return DStream(dstream.asJavaDStream(), self._ssc, jrdd_serializer) def repartitions(self, numPartitions): + """ + Return a new DStream with an increased or decreased level of parallelism. Each RDD in the + returned DStream has exactly numPartitions partitions. + """ return self.transform(lambda rdd: rdd.repartition(numPartitions)) + @property + def _slideDuration(self): + """ + Return the slideDuration in seconds of this DStream + """ + return self._jdstream.dstream().slideDuration().milliseconds() / 1000.0 + def union(self, other): + """ + Return a new DStream by unifying data of another DStream with this DStream. + @param other Another DStream having the same interval (i.e., slideDuration) as this DStream. + """ + if self._slideDuration != other._slideDuration: + raise ValueError("the two DStream should have same slide duration") return self.transformWith(lambda a, b: a.union(b), other, True) - def cogroup(self, other): - return self.transformWith(lambda a, b: a.cogroup(b), other) + def cogroup(self, other, numPartitions=None): + """ + Return a new DStream by applying 'cogroup' between RDDs of `this` + DStream and `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + """ + return self.transformWith(lambda a, b: a.cogroup(b, numPartitions), other) + + def join(self, other, numPartitions=None): + """ + Return a new DStream by applying 'join' between RDDs of `this` DStream and + `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` + partitions. + """ + return self.transformWith(lambda a, b: a.join(b, numPartitions), other) + + def leftOuterJoin(self, other, numPartitions=None): + """ + Return a new DStream by applying 'left outer join' between RDDs of `this` DStream and + `other` DStream. - def leftOuterJoin(self, other): - return self.transformWith(lambda a, b: a.leftOuterJion(b), other) + Hash partitioning is used to generate the RDDs with `numPartitions` + partitions. 
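The stream-stream operations added here (union, cogroup, join and the outer joins) all go through transformWith, which pairs up the two streams' RDDs batch by batch. A small sketch combining them with the new first()/take() helpers; note that first() starts and stops the streaming context itself, so there is no explicit ssc.start() (names and data are illustrative):

from pyspark import SparkContext
from pyspark.streaming.context import StreamingContext

sc = SparkContext("local[2]", "JoinSketch")
ssc = StreamingContext(sc, 0.5)
left = ssc.queueStream([[("a", 1), ("b", 2)]])
right = ssc.queueStream([[("a", "x"), ("c", "y")]])
# per-batch inner join on the key; both streams must share a slide duration
first_batch = left.join(right).first()
print first_batch.collect()     # [('a', (1, 'x'))]
sc.stop()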
+ """ + return self.transformWith(lambda a, b: a.leftOuterJion(b, numPartitions), other) + + def rightOuterJoin(self, other, numPartitions=None): + """ + Return a new DStream by applying 'right outer join' between RDDs of `this` DStream and + `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` + partitions. + """ + return self.transformWith(lambda a, b: a.rightOuterJoin(b, numPartitions), other) + + def fullOuterJoin(self, other, numPartitions=None): + """ + Return a new DStream by applying 'full outer join' between RDDs of `this` DStream and + `other` DStream. - def rightOuterJoin(self, other): - return self.transformWith(lambda a, b: a.rightOuterJoin(b), other) + Hash partitioning is used to generate the RDDs with `numPartitions` + partitions. + """ + return self.transformWith(lambda a, b: a.fullOuterJoin(b, numPartitions), other) - def _jtime(self, milliseconds): - return self.ctx._jvm.Time(milliseconds) + def _jtime(self, timestamp): + """ convert datetime or unix_timestamp into Time + """ + if isinstance(timestamp, datetime): + timestamp = time.mktime(timestamp.timetuple()) + return self.ctx._jvm.Time(long(timestamp * 1000)) def slice(self, begin, end): + """ + Return all the RDDs between 'begin' to 'end' (both included) + + `begin`, `end` could be datetime.datetime() or unix_timestamp + """ jrdds = self._jdstream.slice(self._jtime(begin), self._jtime(end)) return [RDD(jrdd, self.ctx, self._jrdd_deserializer) for jrdd in jrdds] + def _check_window(self, window, slide): + duration = self._jdstream.dstream().slideDuration().milliseconds() + if int(window * 1000) % duration != 0: + raise ValueError("windowDuration must be multiple of the slide duration (%d ms)" + % duration) + if slide and int(slide * 1000) % duration != 0: + raise ValueError("slideDuration must be multiple of the slide duration (%d ms)" + % duration) + def window(self, windowDuration, slideDuration=None): + """ + Return a new DStream in which each RDD contains all the elements in seen in a + sliding window of time over this DStream. + + @param windowDuration width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + """ + self._check_window(windowDuration, slideDuration) d = self._ssc._jduration(windowDuration) if slideDuration is None: return DStream(self._jdstream.window(d), self._ssc, self._jrdd_deserializer) @@ -331,43 +453,108 @@ def window(self, windowDuration, slideDuration=None): return DStream(self._jdstream.window(d, s), self._ssc, self._jrdd_deserializer) def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuration): + """ + Return a new DStream in which each RDD has a single element generated by reducing all + elements in a sliding window over this DStream. + + if `invReduceFunc` is not None, the reduction is done incrementally + using the old window's reduced value : + 1. reduce the new values that entered the window (e.g., adding new counts) + 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + This is more efficient than `invReduceFunc` is None. 
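A usage sketch of the incremental windowing described above: counts are added for records entering the window and subtracted for records leaving it, and both the window and slide widths must be multiples of the batch interval. The checkpoint directory, names and timings are illustrative and mirror the test setup:

import time
from pyspark import SparkContext
from pyspark.streaming.context import StreamingContext

sc = SparkContext("local[2]", "WindowSketch")
sc.setCheckpointDir("/tmp")        # the tests set a checkpoint dir for windowed state
ssc = StreamingContext(sc, 0.5)
words = ssc.queueStream([["a", "b"], ["a"], ["b", "b"]])
pairs = words.map(lambda w: (w, 1))
# 1-second window sliding every 0.5 seconds over 0.5-second batches
counts = pairs.reduceByKeyAndWindow(lambda a, b: a + b, lambda a, b: a - b, 1, 0.5)
result = counts.collect()
ssc.start()
time.sleep(3)
ssc.stop()
print result            # one list of (word, count) pairs per window slide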
+ + @param reduceFunc associative reduce function + @param invReduceFunc inverse reduce function of `reduceFunc` + @param windowDuration width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + """ keyed = self.map(lambda x: (1, x)) reduced = keyed.reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration, 1) return reduced.map(lambda (k, v): v) def countByWindow(self, windowDuration, slideDuration): + """ + Return a new DStream in which each RDD has a single element generated + by counting the number of elements in a window over this DStream. + windowDuration and slideDuration are as defined in the window() operation. + + This is equivalent to window(windowDuration, slideDuration).count(), + but will be more efficient if window is large. + """ return self.map(lambda x: 1).reduceByWindow(operator.add, operator.sub, windowDuration, slideDuration) def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=None): + """ + Return a new DStream in which each RDD contains the count of distinct elements in + RDDs in a sliding window over this DStream. + + @param windowDuration width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + @param numPartitions number of partitions of each RDD in the new DStream. + """ keyed = self.map(lambda x: (x, 1)) counted = keyed.reduceByKeyAndWindow(lambda a, b: a + b, lambda a, b: a - b, windowDuration, slideDuration, numPartitions) return counted.filter(lambda (k, v): v > 0).count() def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None): + """ + Return a new DStream by applying `groupByKey` over a sliding window. + Similar to `DStream.groupByKey()`, but applies it over a sliding window. + + @param windowDuration width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + @param numPartitions Number of partitions of each RDD in the new DStream. + """ ls = self.mapValues(lambda x: [x]) grouped = ls.reduceByKeyAndWindow(lambda a, b: a.extend(b) or a, lambda a, b: a[len(b):], windowDuration, slideDuration, numPartitions) return grouped.mapValues(ResultIterable) - def reduceByKeyAndWindow(self, func, invFunc, - windowDuration, slideDuration, numPartitions=None): + def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None, + numPartitions=None, filterFunc=None): + """ + Return a new DStream by applying incremental `reduceByKey` over a sliding window. + + The reduced value of over a new window is calculated using the old window's reduce value : + 1. reduce the new values that entered the window (e.g., adding new counts) + 2. 
"inverse reduce" the old values that left the window (e.g., subtracting old counts) - duration = self._jdstream.dstream().slideDuration().milliseconds() - if int(windowDuration * 1000) % duration != 0: - raise ValueError("windowDuration must be multiple of the slide duration (%d ms)" - % duration) - if int(slideDuration * 1000) % duration != 0: - raise ValueError("slideDuration must be multiple of the slide duration (%d ms)" - % duration) + `invFunc` can be None, then it will reduce all the RDDs in window, could be slower + than having `invFunc`. + @param reduceFunc associative reduce function + @param invReduceFunc inverse function of `reduceFunc` + @param windowDuration width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + @param numPartitions number of partitions of each RDD in the new DStream. + @param filterFunc function to filter expired key-value pairs; + only pairs that satisfy the function are retained + set this to null if you do not want to filter + """ + self._check_window(windowDuration, slideDuration) reduced = self.reduceByKey(func) def reduceFunc(a, b, t): b = b.reduceByKey(func, numPartitions) - return a.union(b).reduceByKey(func, numPartitions) if a else b + r = a.union(b).reduceByKey(func, numPartitions) if a else b + if filterFunc: + r = r.filter(filterFunc) + return r def invReduceFunc(a, b, t): b = b.reduceByKey(func, numPartitions) @@ -375,7 +562,12 @@ def invReduceFunc(a, b, t): return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1) jreduceFunc = RDDFunction(self.ctx, reduceFunc, reduced._jrdd_deserializer) - jinvReduceFunc = RDDFunction(self.ctx, invReduceFunc, reduced._jrdd_deserializer) + if invReduceFunc: + jinvReduceFunc = RDDFunction(self.ctx, invReduceFunc, reduced._jrdd_deserializer) + else: + jinvReduceFunc = None + if slideDuration is None: + slideDuration = self._slideDuration dstream = self.ctx._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(), jreduceFunc, jinvReduceFunc, self._ssc._jduration(windowDuration), @@ -384,15 +576,20 @@ def invReduceFunc(a, b, t): def updateStateByKey(self, updateFunc, numPartitions=None): """ - :param updateFunc: [(k, vs, s)] -> [(k, s)] + Return a new "state" DStream where the state for each key is updated by applying + the given function on the previous state of the key and the new values of the key. + + @param updateFunc State update function ([(k, vs, s)] -> [(k, s)]). + If `s` is None, then `k` will be eliminated. 
""" def reduceFunc(a, b, t): if a is None: g = b.groupByKey(numPartitions).map(lambda (k, vs): (k, list(vs), None)) else: - g = a.cogroup(b).map(lambda (k, (va, vb)): - (k, list(vb), list(va)[0] if len(va) else None)) - return g.mapPartitions(lambda x: updateFunc(x) or []) + g = a.cogroup(b, numPartitions) + g = g.map(lambda (k, (va, vb)): (k, list(vb), list(va)[0] if len(va) else None)) + state = g.mapPartitions(lambda x: updateFunc(x)) + return state.filter(lambda (k, v): v is not None) jreduceFunc = RDDFunction(self.ctx, reduceFunc, self.ctx.serializer, self._jrdd_deserializer) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 843d6ee04ca33..0ef205754bb58 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -89,6 +89,21 @@ def _sort_result_based_on_key(self, outputs): class TestBasicOperations(PySparkStreamingTestCase): + + def test_take(self): + input = [range(i) for i in range(3)] + dstream = self.ssc.queueStream(input) + rdds = dstream.take(3) + self.assertEqual(3, len(rdds)) + for d, rdd in zip(input, rdds): + self.assertEqual(d, rdd.collect()) + + def test_first(self): + input = [range(10)] + dstream = self.ssc.queueStream(input) + rdd = dstream.first() + self.assertEqual(range(10), rdd.collect()) + def test_map(self): """Basic operation test for DStream.map.""" input = [range(1, 5), range(5, 9), range(9, 13)] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 47c3974b61699..16ac1b93b5f22 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -207,8 +207,10 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]], // Get the RDD of the reduced value of the previous window val previousWindowRDD = getOrCompute(previousWindow.endTime) - // for small window, reduce once will be better than twice - if (windowDuration > slideDuration * 5 && previousWindowRDD.isDefined) { + if (pinvReduceFunc != null && previousWindowRDD.isDefined + // for small window, reduce once will be better than twice + && windowDuration > slideDuration * 5) { + // subtract the values from old RDDs val oldRDDs = parent.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideDuration) @@ -238,4 +240,4 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]], } } } -} \ No newline at end of file +} From b983f0fed06bcbd6e740fbf86af6eb8881e9f3fd Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sun, 28 Sep 2014 23:09:26 -0700 Subject: [PATCH 315/347] address comments --- bin/pyspark | 6 +- .../apache/spark/api/python/PythonRDD.scala | 2 +- .../python/streaming/network_wordcount.py | 2 +- .../src/main/python/streaming/wordcount.py | 2 +- python/pyspark/accumulators.py | 5 ++ python/pyspark/serializers.py | 5 ++ python/pyspark/streaming/dstream.py | 2 +- python/pyspark/streaming/tests.py | 6 -- python/pyspark/streaming/util.py | 5 ++ python/run-tests | 79 ++++++++++--------- 10 files changed, 61 insertions(+), 53 deletions(-) diff --git a/bin/pyspark b/bin/pyspark index 5142411e36974..118e6851af7a0 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -87,11 +87,7 @@ export PYSPARK_SUBMIT_ARGS if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR - if [[ -n "$PYSPARK_DOC_TEST" ]]; then - exec "$PYSPARK_PYTHON" -m doctest $1 - else - exec 
"$PYSPARK_PYTHON" $1 - fi + exec "$PYSPARK_PYTHON" $1 exit fi diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 034a90110af76..19cdbe679fd35 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -293,7 +293,7 @@ private class PythonException(msg: String, cause: Exception) extends RuntimeExce * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python. * This is used by PySpark's shuffle operations. */ -private[spark] class PairwiseRDD(prev: RDD[Array[Byte]]) extends +private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte])](prev) { override def getPartitions = prev.partitions override def compute(split: Partition, context: TaskContext) = diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py index 633e63172bad6..e3b6248c82a12 100644 --- a/examples/src/main/python/streaming/network_wordcount.py +++ b/examples/src/main/python/streaming/network_wordcount.py @@ -14,7 +14,7 @@ counts = lines.flatMap(lambda line: line.split(" "))\ .map(lambda word: (word, 1))\ .reduceByKey(lambda a, b: a+b) - counts.pyprint() + counts.pprint() ssc.start() ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/wordcount.py b/examples/src/main/python/streaming/wordcount.py index c794711845af0..8c08ff0c89850 100644 --- a/examples/src/main/python/streaming/wordcount.py +++ b/examples/src/main/python/streaming/wordcount.py @@ -15,7 +15,7 @@ counts = lines.flatMap(lambda line: line.split(" "))\ .map(lambda x: (x, 1))\ .reduceByKey(lambda a, b: a+b) - counts.pyprint() + counts.pprint() ssc.start() ssc.awaitTermination() diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index ccbca67656c8d..9aa3db7ccf1dd 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -256,3 +256,8 @@ def _start_update_server(): thread.daemon = True thread.start() return server + + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 94bebc310bad6..e666dd9800256 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -526,3 +526,8 @@ def write_int(value, stream): def write_with_length(obj, stream): write_int(len(obj), stream) stream.write(obj) + + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index acd9f27c46cbe..2653e75ccbc54 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -410,7 +410,7 @@ def fullOuterJoin(self, other, numPartitions=None): return self.transformWith(lambda a, b: a.fullOuterJoin(b, numPartitions), other) def _jtime(self, timestamp): - """ convert datetime or unix_timestamp into Time + """ Convert datetime or unix_timestamp into Time """ if isinstance(timestamp, datetime): timestamp = time.mktime(timestamp.timetuple()) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 0ef205754bb58..c547971cd7741 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -29,7 +29,6 @@ from pyspark.context import SparkContext from pyspark.streaming.context import StreamingContext -from pyspark.streaming.duration import Seconds class 
PySparkStreamingTestCase(unittest.TestCase): @@ -46,11 +45,6 @@ def setUp(self): def tearDown(self): self.ssc.stop() - @classmethod - def tearDownClass(cls): - # Make sure tp shutdown the callback server - SparkContext._gateway._shutdown_callback_server() - def _test_func(self, input, func, expected, sort=False): """ @param input: dataset for the test. This should be list of lists. diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index 02b51dc472c51..885411ed63936 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -64,3 +64,8 @@ def rddToFileName(prefix, suffix, time): return prefix + "-" + str(time) else: return prefix + "-" + str(time) + "." + suffix + + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/python/run-tests b/python/run-tests index 5aa9212c8adc1..e8796838c22c1 100755 --- a/python/run-tests +++ b/python/run-tests @@ -48,6 +48,39 @@ function run_test() { fi } +function run_core_tests() { + run_test "pyspark/conf.py" + run_test "pyspark/context.py" + run_test "pyspark/broadcast.py" + run_test "pyspark/accumulators.py" + run_test "pyspark/serializers.py" + run_test "pyspark/shuffle.py" + run_test "pyspark/rdd.py" + run_test "pyspark/tests.py" +} + +function run_sql_tests() { + run_test "pyspark/sql.py" +} + +function run_mllib_tests() { + run_test "pyspark/mllib/util.py" + run_test "pyspark/mllib/linalg.py" + run_test "pyspark/mllib/classification.py" + run_test "pyspark/mllib/clustering.py" + run_test "pyspark/mllib/random.py" + run_test "pyspark/mllib/recommendation.py" + run_test "pyspark/mllib/regression.py" + run_test "pyspark/mllib/stat.py" + run_test "pyspark/mllib/tree.py" + run_test "pyspark/mllib/tests.py" +} + +function run_streaming_tests() { + run_test "pyspark/streaming/util.py" + run_test "pyspark/streaming/tests.py" +} + echo "Running PySpark tests. Output is in python/unit-tests.log." export PYSPARK_PYTHON="python" @@ -60,30 +93,10 @@ fi echo "Testing with Python version:" $PYSPARK_PYTHON --version -run_test "pyspark/rdd.py" -run_test "pyspark/context.py" -run_test "pyspark/conf.py" -run_test "pyspark/sql.py" -# These tests are included in the module-level docs, and so must -# be handled on a higher level rather than within the python file. -export PYSPARK_DOC_TEST=1 -run_test "pyspark/broadcast.py" -run_test "pyspark/accumulators.py" -run_test "pyspark/serializers.py" -unset PYSPARK_DOC_TEST -run_test "pyspark/shuffle.py" -run_test "pyspark/tests.py" -run_test "pyspark/mllib/classification.py" -run_test "pyspark/mllib/clustering.py" -run_test "pyspark/mllib/linalg.py" -run_test "pyspark/mllib/random.py" -run_test "pyspark/mllib/recommendation.py" -run_test "pyspark/mllib/regression.py" -run_test "pyspark/mllib/stat.py" -run_test "pyspark/mllib/tests.py" -run_test "pyspark/mllib/tree.py" -run_test "pyspark/mllib/util.py" -run_test "pyspark/streaming/tests.py" +#run_core_tests +#run_sql_tests +#run_mllib_tests +run_streaming_tests # Try to test with PyPy if [ $(which pypy) ]; then @@ -91,20 +104,10 @@ if [ $(which pypy) ]; then echo "Testing with PyPy version:" $PYSPARK_PYTHON --version - run_test "pyspark/rdd.py" - run_test "pyspark/context.py" - run_test "pyspark/conf.py" - run_test "pyspark/sql.py" - # These tests are included in the module-level docs, and so must - # be handled on a higher level rather than within the python file. 
- export PYSPARK_DOC_TEST=1 - run_test "pyspark/broadcast.py" - run_test "pyspark/accumulators.py" - run_test "pyspark/serializers.py" - unset PYSPARK_DOC_TEST - run_test "pyspark/shuffle.py" - run_test "pyspark/tests.py" - run_test "pyspark/streaming/tests.py" + run_core_tests + run_sql_tests + run_mllib_tests + run_streaming_tests fi if [[ $FAILED == 0 ]]; then From 98ac6c26d63dde9b6ca75177e082dbc421998ef7 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 29 Sep 2014 11:01:39 -0700 Subject: [PATCH 316/347] support ssc.transform() --- python/pyspark/streaming/context.py | 18 +++++-- python/pyspark/streaming/dstream.py | 36 +++++++------- python/pyspark/streaming/tests.py | 13 +++++ python/pyspark/streaming/util.py | 26 +++++----- .../spark/streaming/StreamingContext.scala | 2 +- .../streaming/api/python/PythonDStream.scala | 49 +++++++++++++------ 6 files changed, 96 insertions(+), 48 deletions(-) diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 7879d1b7679d9..ce8aec613d08b 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -20,6 +20,7 @@ from pyspark.context import SparkContext from pyspark.storagelevel import StorageLevel from pyspark.streaming.dstream import DStream +from pyspark.streaming.util import RDDFunction from py4j.java_collections import ListConverter from py4j.java_gateway import java_import @@ -212,11 +213,20 @@ def queueStream(self, queue, oneAtATime=True, default=None): def transform(self, dstreams, transformFunc): """ - Create a new DStream in which each RDD is generated by applying a function on RDDs of - the DStreams. The order of the JavaRDDs in the transform function parameter will be the - same as the order of corresponding DStreams in the list. + Create a new DStream in which each RDD is generated by applying + a function on RDDs of the DStreams. The order of the JavaRDDs in + the transform function parameter will be the same as the order + of corresponding DStreams in the list. """ - # TODO + jdstreams = ListConverter().convert([d._jdstream for d in dstreams], + SparkContext._gateway._gateway_client) + # change the final serializer to sc.serializer + jfunc = RDDFunction(self._sc, + lambda t, *rdds: transformFunc(rdds).map(lambda x: x), + *[d._jrdd_deserializer for d in dstreams]) + + jdstream = self._jvm.PythonDStream.callTransform(self._jssc, jdstreams, jfunc) + return DStream(jdstream, self, self._sc.serializer) def union(self, *dstreams): """ diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 2653e75ccbc54..ae5be72952c76 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -132,7 +132,7 @@ def partitionBy(self, numPartitions, partitionFunc=portable_hash): return self.transform(lambda rdd: rdd.partitionBy(numPartitions, partitionFunc)) def foreach(self, func): - return self.foreachRDD(lambda rdd, _: rdd.foreach(func)) + return self.foreachRDD(lambda _, rdd: rdd.foreach(func)) def foreachRDD(self, func): """ @@ -142,7 +142,7 @@ def foreachRDD(self, func): This is an output operator, so this DStream will be registered as an output stream and there materialized. """ - jfunc = RDDFunction(self.ctx, lambda a, _, t: func(a, t), self._jrdd_deserializer) + jfunc = RDDFunction(self.ctx, func, self._jrdd_deserializer) api = self._ssc._jvm.PythonDStream api.callForeachRDD(self._jdstream, jfunc) @@ -151,10 +151,10 @@ def pprint(self): Print the first ten elements of each RDD generated in this DStream. 
This is an output operator, so this DStream will be registered as an output stream and there materialized. """ - def takeAndPrint(rdd, time): + def takeAndPrint(timestamp, rdd): taken = rdd.take(11) print "-------------------------------------------" - print "Time: %s" % datetime.fromtimestamp(time / 1000.0) + print "Time: %s" % datetime.fromtimestamp(timestamp / 1000.0) print "-------------------------------------------" for record in taken[:10]: print record @@ -176,15 +176,15 @@ def take(self, n): """ rdds = [] - def take(rdd, _): - if rdd: + def take(_, rdd): + if rdd and len(rdds) < n: rdds.append(rdd) - if len(rdds) == n: - # FIXME: NPE in JVM - self._ssc.stop(False) self.foreachRDD(take) + self._ssc.start() - self._ssc.awaitTermination() + while len(rdds) < n: + time.sleep(0.01) + self._ssc.stop(False, True) return rdds def collect(self): @@ -195,7 +195,7 @@ def collect(self): """ result = [] - def get_output(rdd, time): + def get_output(_, rdd): r = rdd.collect() result.append(r) self.foreachRDD(get_output) @@ -317,7 +317,7 @@ def transform(self, func): Return a new DStream in which each RDD is generated by applying a function on each RDD of 'this' DStream. """ - return TransformedDStream(self, lambda a, t: func(a), True) + return TransformedDStream(self, lambda t, a: func(a), True) def transformWithTime(self, func): """ @@ -331,7 +331,7 @@ def transformWith(self, func, other, keepSerializer=False): Return a new DStream in which each RDD is generated by applying a function on each RDD of 'this' DStream and 'other' DStream. """ - jfunc = RDDFunction(self.ctx, lambda a, b, t: func(a, b), self._jrdd_deserializer) + jfunc = RDDFunction(self.ctx, lambda t, a, b: func(a, b), self._jrdd_deserializer) dstream = self.ctx._jvm.PythonTransformed2DStream(self._jdstream.dstream(), other._jdstream.dstream(), jfunc) jrdd_serializer = self._jrdd_deserializer if keepSerializer else self.ctx.serializer @@ -549,14 +549,14 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None self._check_window(windowDuration, slideDuration) reduced = self.reduceByKey(func) - def reduceFunc(a, b, t): + def reduceFunc(t, a, b): b = b.reduceByKey(func, numPartitions) r = a.union(b).reduceByKey(func, numPartitions) if a else b if filterFunc: r = r.filter(filterFunc) return r - def invReduceFunc(a, b, t): + def invReduceFunc(t, a, b): b = b.reduceByKey(func, numPartitions) joined = a.leftOuterJoin(b, numPartitions) return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1) @@ -582,7 +582,7 @@ def updateStateByKey(self, updateFunc, numPartitions=None): @param updateFunc State update function ([(k, vs, s)] -> [(k, s)]). If `s` is None, then `k` will be eliminated. 
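For illustration only, a minimal sketch of an update function matching the signature documented above ([(k, vs, s)] -> [(k, s)]); `word_pairs` is an assumed DStream of (word, 1) pairs and is not part of this patch.

def running_count(pairs):
    # pairs is an iterator of (key, new_values, old_state); old_state is None
    # the first time a key is seen, and yielding (key, None) drops the key
    for key, new_values, state in pairs:
        yield (key, sum(new_values) + (state or 0))

totals = word_pairs.updateStateByKey(running_count)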
""" - def reduceFunc(a, b, t): + def reduceFunc(t, a, b): if a is None: g = b.groupByKey(numPartitions).map(lambda (k, vs): (k, list(vs), None)) else: @@ -610,7 +610,7 @@ def __init__(self, prev, func, reuse=False): not prev.is_cached and not prev.is_checkpointed): prev_func = prev.func old_func = func - func = lambda rdd, t: old_func(prev_func(rdd, t), t) + func = lambda t, rdd: old_func(t, prev_func(t, rdd)) reuse = reuse and prev.reuse prev = prev.prev @@ -625,7 +625,7 @@ def _jdstream(self): return self._jdstream_val func = self.func - jfunc = RDDFunction(self.ctx, lambda a, _, t: func(a, t), self.prev._jrdd_deserializer) + jfunc = RDDFunction(self.ctx, func, self.prev._jrdd_deserializer) jdstream = self.ctx._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc, self.reuse).asJavaDStream() self._jdstream_val = jdstream diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index c547971cd7741..ecf88cce47beb 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -374,6 +374,19 @@ def test_union(self): expected = [i * 2 for i in input] self.assertEqual(expected, result[:3]) + def test_transform(self): + dstream1 = self.ssc.queueStream([[1]]) + dstream2 = self.ssc.queueStream([[2]]) + dstream3 = self.ssc.queueStream([[3]]) + + def func(rdds): + rdd1, rdd2, rdd3 = rdds + return rdd2.union(rdd3).union(rdd1) + + dstream = self.ssc.transform([dstream1, dstream2, dstream3], func) + + self.assertEqual([2, 3, 1], dstream.first().collect()) + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index 885411ed63936..57791805e8f9f 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -22,21 +22,25 @@ class RDDFunction(object): """ This class is for py4j callback. 
""" - def __init__(self, ctx, func, deserializer, deserializer2=None): + def __init__(self, ctx, func, *deserializers): self.ctx = ctx self.func = func - self.deserializer = deserializer - self.deserializer2 = deserializer2 or deserializer + self.deserializers = deserializers + emptyRDD = getattr(self.ctx, "_emptyRDD", None) + if emptyRDD is None: + self.ctx._emptyRDD = emptyRDD = self.ctx.parallelize([]).cache() + self.emptyRDD = emptyRDD - def call(self, jrdd, jrdd2, milliseconds): + def call(self, milliseconds, jrdds): try: - emptyRDD = getattr(self.ctx, "_emptyRDD", None) - if emptyRDD is None: - self.ctx._emptyRDD = emptyRDD = self.ctx.parallelize([]).cache() + # extend deserializers with the first one + sers = self.deserializers + if len(sers) < len(jrdds): + sers += (sers[0],) * (len(jrdds) - len(sers)) - rdd = RDD(jrdd, self.ctx, self.deserializer) if jrdd else emptyRDD - other = RDD(jrdd2, self.ctx, self.deserializer2) if jrdd2 else emptyRDD - r = self.func(rdd, other, milliseconds) + rdds = [RDD(jrdd, self.ctx, ser) if jrdd else self.emptyRDD + for jrdd, ser in zip(jrdds, sers)] + r = self.func(milliseconds, *rdds) if r: return r._jrdd except Exception: @@ -44,7 +48,7 @@ def call(self, jrdd, jrdd2, milliseconds): traceback.print_exc() def __repr__(self): - return "RDDFunction2(%s)" % (str(self.func)) + return "RDDFunction(%s)" % (str(self.func)) class Java: implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction'] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 5a8eef1372e23..ab6a6de074a80 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -413,7 +413,7 @@ class StreamingContext private[streaming] ( dstreams: Seq[DStream[_]], transformFunc: (Seq[RDD[_]], Time) => RDD[T] ): DStream[T] = { - new TransformedDStream[T](dstreams, sparkContext.clean(transformFunc)) + new TransformedDStream[T](dstreams, (transformFunc)) } /** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 16ac1b93b5f22..8ba8c0441ef35 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -17,11 +17,12 @@ package org.apache.spark.streaming.api.python -import java.util.{ArrayList => JArrayList} +import java.util.{ArrayList => JArrayList, List => JList} import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.collection.mutable import org.apache.spark.api.java._ -import org.apache.spark.api.java.function.{Function2 => JFunction2} import org.apache.spark.api.python._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -29,18 +30,19 @@ import org.apache.spark.streaming.{Interval, Duration, Time} import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.api.java._ + /** * Interface for Python callback function with three arguments */ trait PythonRDDFunction { - def call(rdd: JavaRDD[_], rdd2: JavaRDD[_], time: Long): JavaRDD[Array[Byte]] + def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] } -class RDDFunction(pfunc: 
PythonRDDFunction) extends Serializable { - - def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { - apply(rdd, None, time) - } +/** + * Wrapper for PythonRDDFunction + */ +class RDDFunction(pfunc: PythonRDDFunction) + extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] with Serializable { def wrapRDD(rdd: Option[RDD[_]]): JavaRDD[_] = { if (rdd.isDefined) { @@ -50,14 +52,25 @@ class RDDFunction(pfunc: PythonRDDFunction) extends Serializable { } } - def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { - val r = pfunc.call(wrapRDD(rdd), wrapRDD(rdd2), time.milliseconds) - if (r != null) { - Some(r.rdd) + def some(jrdd: JavaRDD[Array[Byte]]): Option[RDD[Array[Byte]]] = { + if (jrdd != null) { + Some(jrdd.rdd) } else { None } } + + def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { + some(pfunc.call(time.milliseconds, List(wrapRDD(rdd)).asJava)) + } + + def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { + some(pfunc.call(time.milliseconds, List(wrapRDD(rdd), wrapRDD(rdd2)).asJava)) + } + + def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = { + pfunc.call(time.milliseconds, rdds) + } } private[python] @@ -74,8 +87,16 @@ private[spark] object PythonDStream { // helper function for DStream.foreachRDD(), // cannot be `foreachRDD`, it will confusing py4j - def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pyfunc: PythonRDDFunction): Unit = { - jdstream.dstream.foreachRDD((rdd, time) => pyfunc.call(rdd, null, time.milliseconds)) + def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pyfunc: PythonRDDFunction){ + val func = new RDDFunction(pyfunc) + jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time)) + } + + // helper function for ssc.transform() + def callTransform(ssc: JavaStreamingContext, jdsteams: JList[JavaDStream[_]], pyfunc: PythonRDDFunction) + :JavaDStream[Array[Byte]] = { + val func = new RDDFunction(pyfunc) + ssc.transform(jdsteams, func) } // convert list of RDD into queue of RDDs, for ssc.queueStream() From c40c52df9fd8b6dc8fd44196a73d57bd97a43a06 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 29 Sep 2014 11:10:05 -0700 Subject: [PATCH 317/347] change first(), take(n) to has the same behavior as RDD --- python/pyspark/streaming/dstream.py | 11 ++++++----- python/pyspark/streaming/tests.py | 10 +++------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index ae5be72952c76..8f02d95e03d43 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -174,18 +174,19 @@ def take(self, n): """ Return the first `n` RDDs in the stream (will start and stop). 
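A hedged usage sketch of the ssc.transform() support added in the previous patch, mirroring its test_transform test; `ssc` is an assumed, already created StreamingContext.

d1 = ssc.queueStream([[1]])
d2 = ssc.queueStream([[2]])

def reorder(rdds):
    # RDDs arrive in the same order as the DStreams passed to transform()
    rdd1, rdd2 = rdds
    return rdd2.union(rdd1)

merged = ssc.transform([d1, d2], reorder)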
""" - rdds = [] + results = [] def take(_, rdd): - if rdd and len(rdds) < n: - rdds.append(rdd) + if rdd and len(results) < n: + results.extend(rdd.take(n - len(results))) + self.foreachRDD(take) self._ssc.start() - while len(rdds) < n: + while len(results) < n: time.sleep(0.01) self._ssc.stop(False, True) - return rdds + return results def collect(self): """ diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index ecf88cce47beb..828c40f247629 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -87,16 +87,12 @@ class TestBasicOperations(PySparkStreamingTestCase): def test_take(self): input = [range(i) for i in range(3)] dstream = self.ssc.queueStream(input) - rdds = dstream.take(3) - self.assertEqual(3, len(rdds)) - for d, rdd in zip(input, rdds): - self.assertEqual(d, rdd.collect()) + self.assertEqual([0, 0, 1], dstream.take(3)) def test_first(self): input = [range(10)] dstream = self.ssc.queueStream(input) - rdd = dstream.first() - self.assertEqual(range(10), rdd.collect()) + self.assertEqual(0, dstream) def test_map(self): """Basic operation test for DStream.map.""" @@ -385,7 +381,7 @@ def func(rdds): dstream = self.ssc.transform([dstream1, dstream2, dstream3], func) - self.assertEqual([2, 3, 1], dstream.first().collect()) + self.assertEqual([2, 3, 1], dstream.take(3)) if __name__ == "__main__": From 6ebceca528dbd94dc23eba4412715e661ff6527e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 29 Sep 2014 13:26:06 -0700 Subject: [PATCH 318/347] add more tests --- python/pyspark/streaming/dstream.py | 8 +- python/pyspark/streaming/tests.py | 156 +++++++++++++----- .../streaming/api/python/PythonDStream.scala | 34 ++-- 3 files changed, 137 insertions(+), 61 deletions(-) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 8f02d95e03d43..c18c68dfe5a32 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -286,7 +286,7 @@ def saveAsTextFiles(self, prefix, suffix=None): Save this DStream as a text file, using string representations of elements. """ - def saveAsTextFile(rdd, time): + def saveAsTextFile(time, rdd): """ Closure to save element in RDD in DStream as Pickled data in file. This closure is called by py4j callback server. @@ -303,7 +303,7 @@ def saveAsPickleFiles(self, prefix, suffix=None): is 10. """ - def saveAsPickleFile(rdd, time): + def saveAsPickleFile(time, rdd): """ Closure to save element in RDD in the DStream as Pickled data in file. This closure is called by py4j callback server. @@ -388,7 +388,7 @@ def leftOuterJoin(self, other, numPartitions=None): Hash partitioning is used to generate the RDDs with `numPartitions` partitions. """ - return self.transformWith(lambda a, b: a.leftOuterJion(b, numPartitions), other) + return self.transformWith(lambda a, b: a.leftOuterJoin(b, numPartitions), other) def rightOuterJoin(self, other, numPartitions=None): """ @@ -502,7 +502,7 @@ def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=Non @param numPartitions number of partitions of each RDD in the new DStream. 
""" keyed = self.map(lambda x: (x, 1)) - counted = keyed.reduceByKeyAndWindow(lambda a, b: a + b, lambda a, b: a - b, + counted = keyed.reduceByKeyAndWindow(operator.add, operator.sub, windowDuration, slideDuration, numPartitions) return counted.filter(lambda (k, v): v > 0).count() diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 828c40f247629..54d4d9b1f7850 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -15,17 +15,12 @@ # limitations under the License. # -""" -Unit tests for Python SparkStreaming; additional tests are implemented as doctests in -individual modules. - -Callback server is sometimes unstable sometimes, which cause error in test case. -But this is very rare case. -""" +import os from itertools import chain import time import operator import unittest +import tempfile from pyspark.context import SparkContext from pyspark.streaming.context import StreamingContext @@ -45,16 +40,20 @@ def setUp(self): def tearDown(self): self.ssc.stop() - def _test_func(self, input, func, expected, sort=False): + def _test_func(self, input, func, expected, sort=False, input2=None): """ @param input: dataset for the test. This should be list of lists. @param func: wrapped function. This function should return PythonDStream object. @param expected: expected output for this testcase. """ - # Generate input stream with user-defined input. input_stream = self.ssc.queueStream(input) + input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None # Apply test function to stream. - stream = func(input_stream) + if input2: + stream = func(input_stream, input_stream2) + else: + stream = func(input_stream) + result = stream.collect() self.ssc.start() @@ -92,7 +91,7 @@ def test_take(self): def test_first(self): input = [range(10)] dstream = self.ssc.queueStream(input) - self.assertEqual(0, dstream) + self.assertEqual(0, dstream.first()) def test_map(self): """Basic operation test for DStream.map.""" @@ -238,55 +237,122 @@ def add(a, b): [("a", "11"), ("b", "1"), ("", "111")]] self._test_func(input, func, expected, sort=True) + def test_repartition(self): + input = [range(1, 5), range(5, 9)] + rdds = [self.sc.parallelize(r, 2) for r in input] + + def func(dstream): + return dstream.repartitions(1).glom() + expected = [[[1, 2, 3, 4]], [[5, 6, 7, 8]]] + self._test_func(rdds, func, expected) + def test_union(self): - input1 = [range(3), range(5), range(1), range(6)] - input2 = [range(3, 6), range(5, 6), range(1, 6)] + input1 = [range(3), range(5), range(6)] + input2 = [range(3, 6), range(5, 6)] - d1 = self.ssc.queueStream(input1) - d2 = self.ssc.queueStream(input2) - d = d1.union(d2) - result = d.collect() - expected = [range(6), range(6), range(6), range(6)] + def func(d1, d2): + return d1.union(d2) - self.ssc.start() - start_time = time.time() - # Loop until get the expected the number of the result from the stream. - while True: - current_time = time.time() - # Check time out. - if (current_time - start_time) > self.timeout * 2: - break - # StreamingContext.awaitTermination is not used to wait because - # if py4j server is called every 50 milliseconds, it gets an error. - time.sleep(0.05) - # Check if the output is the same length of expected output. 
- if len(expected) == len(result): - break - self.assertEqual(expected, result) + expected = [range(6), range(6), range(6)] + self._test_func(input1, func, expected, input2=input2) + + def test_cogroup(self): + input = [[(1, 1), (2, 1), (3, 1)], + [(1, 1), (1, 1), (1, 1), (2, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1)]] + input2 = [[(1, 2)], + [(4, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 2)]] + + def func(d1, d2): + return d1.cogroup(d2).mapValues(lambda vs: tuple(map(list, vs))) + + expected = [[(1, ([1], [2])), (2, ([1], [])), (3, ([1], []))], + [(1, ([1, 1, 1], [])), (2, ([1], [])), (4, ([], [1]))], + [("a", ([1, 1], [1, 1])), ("b", ([1], [1])), ("", ([1, 1], [1, 2]))]] + self._test_func(input, func, expected, sort=True, input2=input2) + + def test_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = [[('b', 3), ('c', 4)]] + + def func(a, b): + return a.join(b) + + expected = [[('b', (2, 3))]] + self._test_func(input, func, expected, True, input2) + + def test_left_outer_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = [[('b', 3), ('c', 4)]] + + def func(a, b): + return a.leftOuterJoin(b) + + expected = [[('a', (1, None)), ('b', (2, 3))]] + self._test_func(input, func, expected, True, input2) + + def test_right_outer_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = [[('b', 3), ('c', 4)]] + + def func(a, b): + return a.rightOuterJoin(b) + + expected = [[('b', (2, 3)), ('c', (None, 4))]] + self._test_func(input, func, expected, True, input2) + + def test_full_outer_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = [[('b', 3), ('c', 4)]] + + def func(a, b): + return a.fullOuterJoin(b) + + expected = [[('a', (1, None)), ('b', (2, 3)), ('c', (None, 4))]] + self._test_func(input, func, expected, True, input2) class TestWindowFunctions(PySparkStreamingTestCase): - timeout = 15 + timeout = 20 + + def test_window(self): + input = [range(1), range(2), range(3), range(4), range(5)] + + def func(dstream): + return dstream.window(3, 1).count() + + expected = [[1], [3], [6], [9], [12], [9], [5]] + self._test_func(input, func, expected) def test_count_by_window(self): - input = [range(1), range(2), range(3), range(4), range(5), range(6)] + input = [range(1), range(2), range(3), range(4), range(5)] def func(dstream): - return dstream.countByWindow(4, 1) + return dstream.countByWindow(3, 1) - expected = [[1], [3], [6], [9], [12], [15], [11], [6]] + expected = [[1], [3], [6], [9], [12], [9], [5]] self._test_func(input, func, expected) def test_count_by_window_large(self): input = [range(1), range(2), range(3), range(4), range(5), range(6)] def func(dstream): - return dstream.countByWindow(6, 1) + return dstream.countByWindow(5, 1) expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]] self._test_func(input, func, expected) + def test_count_by_value_and_window(self): + input = [range(1), range(2), range(3), range(4), range(5), range(6)] + + def func(dstream): + return dstream.countByValueAndWindow(6, 1) + + expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]] + self._test_func(input, func, expected) + def test_group_by_key_and_window(self): input = [[('a', i)] for i in range(5)] @@ -359,6 +425,20 @@ def test_queueStream(self): time.sleep(1) self.assertEqual(input, result[:3]) + # TODO: test textFileStream + # def test_textFileStream(self): + # input = [range(i) for i in range(3)] + # dstream = self.ssc.queueStream(input) + # d = os.path.join(tempfile.gettempdir(), str(id(self))) + # if not os.path.exists(d): + # os.makedirs(d) + # 
dstream.saveAsTextFiles(os.path.join(d, 'test')) + # dstream2 = self.ssc.textFileStream(d) + # result = dstream2.collect() + # self.ssc.start() + # time.sleep(2) + # self.assertEqual(input, result[:3]) + def test_union(self): input = [range(i) for i in range(3)] dstream = self.ssc.queueStream(input) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 8ba8c0441ef35..2f20b05991b8e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -93,7 +93,8 @@ private[spark] object PythonDStream { } // helper function for ssc.transform() - def callTransform(ssc: JavaStreamingContext, jdsteams: JList[JavaDStream[_]], pyfunc: PythonRDDFunction) + def callTransform(ssc: JavaStreamingContext, jdsteams: JList[JavaDStream[_]], + pyfunc: PythonRDDFunction) :JavaDStream[Array[Byte]] = { val func = new RDDFunction(pyfunc) ssc.transform(jdsteams, func) @@ -210,9 +211,9 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]], override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { val currentTime = validTime - val currentWindow = new Interval(currentTime - windowDuration + parent.slideDuration, + val current = new Interval(currentTime - windowDuration, currentTime) - val previousWindow = currentWindow - slideDuration + val previous = current - slideDuration // _____________________________ // | previous window _________|___________________ @@ -225,35 +226,30 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]], // old RDDs new RDDs // - // Get the RDD of the reduced value of the previous window - val previousWindowRDD = getOrCompute(previousWindow.endTime) + val previousRDD = getOrCompute(previous.endTime) - if (pinvReduceFunc != null && previousWindowRDD.isDefined + if (pinvReduceFunc != null && previousRDD.isDefined // for small window, reduce once will be better than twice - && windowDuration > slideDuration * 5) { + && windowDuration >= slideDuration * 5) { // subtract the values from old RDDs - val oldRDDs = - parent.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideDuration) - val subbed = if (oldRDDs.size > 0) { - invReduceFunc(previousWindowRDD, Some(ssc.sc.union(oldRDDs)), validTime) + val oldRDDs = parent.slice(previous.beginTime + parent.slideDuration, current.beginTime) + val subtracted = if (oldRDDs.size > 0) { + invReduceFunc(previousRDD, Some(ssc.sc.union(oldRDDs)), validTime) } else { - previousWindowRDD + previousRDD } // add the RDDs of the reduced values in "new time steps" - val newRDDs = - parent.slice(previousWindow.endTime, currentWindow.endTime - parent.slideDuration) - + val newRDDs = parent.slice(previous.endTime + parent.slideDuration, current.endTime) if (newRDDs.size > 0) { - reduceFunc(subbed, Some(ssc.sc.union(newRDDs)), validTime) + reduceFunc(subtracted, Some(ssc.sc.union(newRDDs)), validTime) } else { - subbed + subtracted } } else { // Get the RDDs of the reduced values in current window - val currentRDDs = - parent.slice(currentWindow.beginTime, currentWindow.endTime - parent.slideDuration) + val currentRDDs = parent.slice(current.beginTime + parent.slideDuration, current.endTime) if (currentRDDs.size > 0) { reduceFunc(None, Some(ssc.sc.union(currentRDDs)), validTime) } else { From 19797f9fc9b062ee30746c184ad432192ca5e19a Mon Sep 17 00:00:00 2001 From: Davies 
Liu Date: Mon, 29 Sep 2014 13:41:44 -0700 Subject: [PATCH 319/347] clean up --- python/pyspark/streaming/context.py | 6 +++--- python/pyspark/streaming/tests.py | 4 ++-- .../scala/org/apache/spark/streaming/StreamingContext.scala | 2 +- .../spark/streaming/api/java/JavaStreamingContext.scala | 4 ---- .../apache/spark/streaming/api/python/PythonDStream.scala | 3 ++- 5 files changed, 8 insertions(+), 11 deletions(-) diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index ce8aec613d08b..425b0a96aa832 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -15,6 +15,9 @@ # limitations under the License. # +from py4j.java_collections import ListConverter +from py4j.java_gateway import java_import + from pyspark import RDD from pyspark.serializers import UTF8Deserializer from pyspark.context import SparkContext @@ -22,9 +25,6 @@ from pyspark.streaming.dstream import DStream from pyspark.streaming.util import RDDFunction -from py4j.java_collections import ListConverter -from py4j.java_gateway import java_import - __all__ = ["StreamingContext"] diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 54d4d9b1f7850..342afde3bffd2 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -348,7 +348,7 @@ def test_count_by_value_and_window(self): input = [range(1), range(2), range(3), range(4), range(5), range(6)] def func(dstream): - return dstream.countByValueAndWindow(6, 1) + return dstream.countByValueAndWindow(5, 1) expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]] self._test_func(input, func, expected) @@ -357,7 +357,7 @@ def test_group_by_key_and_window(self): input = [[('a', i)] for i in range(5)] def func(dstream): - return dstream.groupByKeyAndWindow(4, 1).mapValues(list) + return dstream.groupByKeyAndWindow(3, 1).mapValues(list) expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])], [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index ab6a6de074a80..ef7631788f26d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -413,7 +413,7 @@ class StreamingContext private[streaming] ( dstreams: Seq[DStream[_]], transformFunc: (Seq[RDD[_]], Time) => RDD[T] ): DStream[T] = { - new TransformedDStream[T](dstreams, (transformFunc)) + new TransformedDStream[T](dstreams, transformFunc) } /** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 662cd8d22c6a5..9dc26dc6b32a1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -549,10 +549,6 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * JavaStreamingContext object contains a number of utility functions. 
*/ object JavaStreamingContext { - implicit def fromStreamingContext(ssc: StreamingContext): - JavaStreamingContext = new JavaStreamingContext(ssc) - - implicit def toStreamingContext(jssc: JavaStreamingContext): StreamingContext = jssc.ssc /** * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 2f20b05991b8e..30c52c15e9e68 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -41,7 +41,7 @@ trait PythonRDDFunction { /** * Wrapper for PythonRDDFunction */ -class RDDFunction(pfunc: PythonRDDFunction) +private[python] class RDDFunction(pfunc: PythonRDDFunction) extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] with Serializable { def wrapRDD(rdd: Option[RDD[_]]): JavaRDD[_] = { @@ -68,6 +68,7 @@ class RDDFunction(pfunc: PythonRDDFunction) some(pfunc.call(time.milliseconds, List(wrapRDD(rdd), wrapRDD(rdd2)).asJava)) } + // for JFunction2 def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = { pfunc.call(time.milliseconds, rdds) } From 338580a7aa39fcf8beedefdc7000b906a1028c84 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 29 Sep 2014 17:02:12 -0700 Subject: [PATCH 320/347] change _first(), _take(), _collect() as private API --- python/pyspark/streaming/dstream.py | 8 ++++---- python/pyspark/streaming/tests.py | 23 ++++++++++++++--------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index c18c68dfe5a32..d98afc3e5a294 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -164,13 +164,13 @@ def takeAndPrint(timestamp, rdd): self.foreachRDD(takeAndPrint) - def first(self): + def _first(self): """ Return the first RDD in the stream. """ - return self.take(1)[0] + return self._take(1)[0] - def take(self, n): + def _take(self, n): """ Return the first `n` RDDs in the stream (will start and stop). """ @@ -188,7 +188,7 @@ def take(_, rdd): self._ssc.stop(False, True) return results - def collect(self): + def _collect(self): """ Collect each RDDs into the returned list. 
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 342afde3bffd2..7ffdb145c104e 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -54,7 +54,7 @@ def _test_func(self, input, func, expected, sort=False, input2=None): else: stream = func(input_stream) - result = stream.collect() + result = stream._collect() self.ssc.start() start_time = time.time() @@ -86,12 +86,12 @@ class TestBasicOperations(PySparkStreamingTestCase): def test_take(self): input = [range(i) for i in range(3)] dstream = self.ssc.queueStream(input) - self.assertEqual([0, 0, 1], dstream.take(3)) + self.assertEqual([0, 0, 1], dstream._take(3)) def test_first(self): input = [range(10)] dstream = self.ssc.queueStream(input) - self.assertEqual(0, dstream.first()) + self.assertEqual(0, dstream._first()) def test_map(self): """Basic operation test for DStream.map.""" @@ -415,17 +415,17 @@ def _addInputStream(self): # Make sure each length of input is over 3 inputs = map(lambda x: range(1, x), range(5, 101)) stream = self.ssc.queueStream(inputs) - stream.collect() + stream._collect() def test_queueStream(self): input = [range(i) for i in range(3)] dstream = self.ssc.queueStream(input) - result = dstream.collect() + result = dstream._collect() self.ssc.start() time.sleep(1) self.assertEqual(input, result[:3]) - # TODO: test textFileStream + # TODO: fix this test # def test_textFileStream(self): # input = [range(i) for i in range(3)] # dstream = self.ssc.queueStream(input) @@ -433,8 +433,13 @@ def test_queueStream(self): # if not os.path.exists(d): # os.makedirs(d) # dstream.saveAsTextFiles(os.path.join(d, 'test')) + # self.ssc.start() + # time.sleep(1) + # self.ssc.stop(False, True) + # + # self.ssc = StreamingContext(self.sc, self.batachDuration) # dstream2 = self.ssc.textFileStream(d) - # result = dstream2.collect() + # result = dstream2._collect() # self.ssc.start() # time.sleep(2) # self.assertEqual(input, result[:3]) @@ -444,7 +449,7 @@ def test_union(self): dstream = self.ssc.queueStream(input) dstream2 = self.ssc.queueStream(input) dstream3 = self.ssc.union(dstream, dstream2) - result = dstream3.collect() + result = dstream3._collect() self.ssc.start() time.sleep(1) expected = [i * 2 for i in input] @@ -461,7 +466,7 @@ def func(rdds): dstream = self.ssc.transform([dstream1, dstream2, dstream3], func) - self.assertEqual([2, 3, 1], dstream.take(3)) + self.assertEqual([2, 3, 1], dstream._take(3)) if __name__ == "__main__": From 069a94c2f12211560691177f465a74630531e81b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 29 Sep 2014 22:48:47 -0700 Subject: [PATCH 321/347] fix the number of partitions during window() --- python/pyspark/streaming/dstream.py | 12 +++++++++--- python/pyspark/streaming/tests.py | 8 +++++++- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index d98afc3e5a294..d866f8c9687fb 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -552,14 +552,18 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None def reduceFunc(t, a, b): b = b.reduceByKey(func, numPartitions) - r = a.union(b).reduceByKey(func, numPartitions) if a else b + # use the average of number of partitions, or it will keep increasing + partitions = numPartitions or (a.getNumPartitions() + b.getNumPartitions())/2 + r = a.union(b).reduceByKey(func, partitions) if a else b if filterFunc: r = r.filter(filterFunc) return r 
def invReduceFunc(t, a, b): b = b.reduceByKey(func, numPartitions) - joined = a.leftOuterJoin(b, numPartitions) + # use the average of number of partitions, or it will keep increasing + partitions = numPartitions or (a.getNumPartitions() + b.getNumPartitions())/2 + joined = a.leftOuterJoin(b, partitions) return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1) jreduceFunc = RDDFunction(self.ctx, reduceFunc, reduced._jrdd_deserializer) @@ -587,7 +591,9 @@ def reduceFunc(t, a, b): if a is None: g = b.groupByKey(numPartitions).map(lambda (k, vs): (k, list(vs), None)) else: - g = a.cogroup(b, numPartitions) + # use the average of number of partitions, or it will keep increasing + partitions = numPartitions or (a.getNumPartitions() + b.getNumPartitions())/2 + g = a.cogroup(b, partitions) g = g.map(lambda (k, (va, vb)): (k, list(vb), list(va)[0] if len(va) else None)) state = g.mapPartitions(lambda x: updateFunc(x)) return state.filter(lambda (k, v): v is not None) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 7ffdb145c104e..0dc6b3d675397 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -22,7 +22,7 @@ import unittest import tempfile -from pyspark.context import SparkContext +from pyspark.context import SparkContext, RDD from pyspark.streaming.context import StreamingContext @@ -46,8 +46,13 @@ def _test_func(self, input, func, expected, sort=False, input2=None): @param func: wrapped function. This function should return PythonDStream object. @param expected: expected output for this testcase. """ + if not isinstance(input[0], RDD): + input = [self.sc.parallelize(d, 1) for d in input] input_stream = self.ssc.queueStream(input) + if input2 and not isinstance(input2[0], RDD): + input2 = [self.sc.parallelize(d, 1) for d in input2] input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None + # Apply test function to stream. if input2: stream = func(input_stream, input_stream2) @@ -63,6 +68,7 @@ def _test_func(self, input, func, expected, sort=False, input2=None): current_time = time.time() # Check time out. if (current_time - start_time) > self.timeout: + print "timeout after", self.timeout break # StreamingContext.awaitTermination is not used to wait because # if py4j server is called every 50 milliseconds, it gets an error. 
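A hedged sketch of the incremental reduceByKeyAndWindow call whose partitioning this patch adjusts; `pairs` is an assumed DStream of (key, int) pairs and the durations are in seconds.

windowed = pairs.reduceByKeyAndWindow(
    lambda a, b: a + b,   # applied to values entering the window
    lambda a, b: a - b,   # inverse, applied to values leaving the window
    30, 10)               # windowDuration=30, slideDuration=10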
From e00136b3dfd330689d89e44006a49871b36a4825 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 30 Sep 2014 00:41:39 -0700 Subject: [PATCH 322/347] address comments --- .../apache/spark/api/python/PythonRDD.scala | 1 + .../{wordcount.py => hdfs_wordcount.py} | 0 python/pyspark/java_gateway.py | 1 - python/pyspark/streaming/context.py | 32 ++- python/pyspark/streaming/dstream.py | 254 ++++++++---------- python/pyspark/streaming/tests.py | 69 +++-- python/pyspark/streaming/util.py | 5 +- python/run-tests | 6 +- .../streaming/api/python/PythonDStream.scala | 98 ++++--- 9 files changed, 245 insertions(+), 221 deletions(-) rename examples/src/main/python/streaming/{wordcount.py => hdfs_wordcount.py} (100%) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 19cdbe679fd35..8051b221ac3d1 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -52,6 +52,7 @@ private[spark] class PythonRDD( accumulator: Accumulator[JList[Array[Byte]]]) extends RDD[Array[Byte]](parent) { + // create a new PythonRDD with same Python setting but different parent. def copyTo(rdd: RDD[_]): PythonRDD = { new PythonRDD(rdd, command, envVars, pythonIncludes, preservePartitoning, pythonExec, broadcastVars, accumulator) diff --git a/examples/src/main/python/streaming/wordcount.py b/examples/src/main/python/streaming/hdfs_wordcount.py similarity index 100% rename from examples/src/main/python/streaming/wordcount.py rename to examples/src/main/python/streaming/hdfs_wordcount.py diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index db5b97f8472d1..9c70fa5c16d0c 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -23,7 +23,6 @@ import platform from subprocess import Popen, PIPE from threading import Thread - from py4j.java_gateway import java_import, JavaGateway, GatewayClient diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 425b0a96aa832..ae4a1d5b6b069 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -31,6 +31,11 @@ def _daemonize_callback_server(): """ Hack Py4J to daemonize callback server + + The thread of callback server has daemon=False, it will block the driver + from exiting if it's not shutdown. The following code replace `start()` + of CallbackServer with a new version, which set daemon=True for this + thread. """ # TODO: create a patch for Py4J import socket @@ -47,7 +52,6 @@ def start(self): 1) try: self.server_socket.bind((self.address, self.port)) - # self.port = self.server_socket.getsockname()[1] except Exception: msg = 'An error occurred while trying to start the callback server' logger.exception(msg) @@ -63,19 +67,21 @@ def start(self): class StreamingContext(object): """ - Main entry point for Spark Streaming functionality. A StreamingContext represents the - connection to a Spark cluster, and can be used to create L{DStream}s and - broadcast variables on that cluster. + Main entry point for Spark Streaming functionality. A StreamingContext + represents the connection to a Spark cluster, and can be used to create + L{DStream}s various input sources. It can be from an existing L{SparkContext}. + After creating and transforming DStreams, the streaming computation can + be started and stopped using `context.start()` and `context.stop()`, + respectively. 
`context.awaitTransformation()` allows the current thread + to wait for the termination of the context by `stop()` or by an exception. """ def __init__(self, sparkContext, duration): """ - Create a new StreamingContext. At least the master and app name and duration - should be set, either through the named parameters here or through C{conf}. + Create a new StreamingContext. @param sparkContext: L{SparkContext} object. - @param duration: seconds for SparkStreaming. - + @param duration: number of seconds. """ self._sc = sparkContext self._jvm = self._sc._jvm @@ -127,8 +133,12 @@ def awaitTermination(self, timeout=None): def stop(self, stopSparkContext=True, stopGraceFully=False): """ - Stop the execution of the streams immediately (does not wait for all received data - to be processed). + Stop the execution of the streams, with option of ensuring all + received data has been processed. + + @param stopSparkContext Stop the associated SparkContext or not + @param stopGracefully Stop gracefully by waiting for the processing + of all received data to be completed """ self._jssc.stop(stopSparkContext, stopGraceFully) if stopSparkContext: @@ -140,7 +150,7 @@ def remember(self, duration): in the last given duration. DStreams remember RDDs only for a limited duration of time and releases them for garbage collection. This method allows the developer to specify how to long to remember - the RDDs ( if the developer wishes to query old data outside the + the RDDs (if the developer wishes to query old data outside the DStream computation). @param duration Minimum duration (in seconds) that each DStream diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index d866f8c9687fb..4e3f07e26953b 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -30,6 +30,24 @@ class DStream(object): + """ + A Discretized Stream (DStream), the basic abstraction in Spark Streaming, + is a continuous sequence of RDDs (of the same type) representing a + continuous stream of data (see L{RDD} in the Spark core documentation + for more details on RDDs). + + DStreams can either be created from live data (such as, data from TCP + sockets, Kafka, Flume, etc.) using a L{StreamingContext} or it can be + generated by transforming existing DStreams using operations such as + `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming + program is running, each DStream periodically generates a RDD, either + from live data or by transforming the RDD generated by a parent DStream. + + DStreams internally is characterized by a few basic properties: + - A list of other DStreams that the DStream depends on + - A time interval at which the DStream generates an RDD + - A function that is used to generate an RDD after each time interval + """ def __init__(self, jdstream, ssc, jrdd_deserializer): self._jdstream = jdstream self._ssc = ssc @@ -46,11 +64,12 @@ def context(self): def count(self): """ - Return a new DStream which contains the number of elements in this DStream. + Return a new DStream in which each RDD has a single element + generated by counting each RDD of this DStream. """ - return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum() + return self.mapPartitions(lambda i: [sum(1 for _ in i)])._sum() - def sum(self): + def _sum(self): """ Add up the elements in this DStream. 
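A hedged illustration of the count()/_sum() pair above, producing one value per batch; `ssc` is an assumed StreamingContext.

d = ssc.queueStream([range(3), range(5)])
d.count().pprint()   # expected to print roughly 3 for the first batch, then 5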
""" @@ -66,8 +85,8 @@ def func(iterator): def flatMap(self, f, preservesPartitioning=False): """ - Pass each value in the key-value pair DStream through flatMap function - without changing the keys: this also retains the original RDD's partition. + Return a new DStream by applying a function to all elements of + this DStream, and then flattening the results """ def func(s, iterator): return chain.from_iterable(imap(f, iterator)) @@ -83,7 +102,8 @@ def func(iterator): def mapPartitions(self, f, preservesPartitioning=False): """ - Return a new DStream by applying a function to each partition of this DStream. + Return a new DStream in which each RDD is generated by applying + mapPartitions() to each RDDs of this DStream. """ def func(s, iterator): return f(iterator) @@ -91,56 +111,51 @@ def func(s, iterator): def mapPartitionsWithIndex(self, f, preservesPartitioning=False): """ - Return a new DStream by applying a function to each partition of this DStream, - while tracking the index of the original partition. + Return a new DStream in which each RDD is generated by applying + mapPartitionsWithIndex() to each RDDs of this DStream. """ return self.transform(lambda rdd: rdd.mapPartitionsWithIndex(f, preservesPartitioning)) def reduce(self, func): """ - Return a new DStream by reduceing the elements of this RDD using the specified - commutative and associative binary operator. + Return a new DStream in which each RDD has a single element + generated by reducing each RDD of this DStream. """ return self.map(lambda x: (None, x)).reduceByKey(func, 1).map(lambda x: x[1]) def reduceByKey(self, func, numPartitions=None): """ - Merge the value for each key using an associative reduce function. - - This will also perform the merging locally on each mapper before - sending results to reducer, similarly to a "combiner" in MapReduce. - - Output will be hash-partitioned with C{numPartitions} partitions, or - the default parallelism level if C{numPartitions} is not specified. + Return a new DStream by applying reduceByKey to each RDD. """ + if numPartitions is None: + numPartitions = self.ctx.defaultParallelism return self.combineByKey(lambda x: x, func, func, numPartitions) def combineByKey(self, createCombiner, mergeValue, mergeCombiners, numPartitions=None): """ - Count the number of elements for each key, and return the result to the - master as a dictionary + Return a new DStream by applying combineByKey to each RDD. """ + if numPartitions is None: + numPartitions = self.ctx.defaultParallelism + def func(rdd): return rdd.combineByKey(createCombiner, mergeValue, mergeCombiners, numPartitions) return self.transform(func) def partitionBy(self, numPartitions, partitionFunc=portable_hash): """ - Return a copy of the DStream partitioned using the specified partitioner. + Return a copy of the DStream in which each RDD are partitioned + using the specified partitioner. """ return self.transform(lambda rdd: rdd.partitionBy(numPartitions, partitionFunc)) - def foreach(self, func): - return self.foreachRDD(lambda _, rdd: rdd.foreach(func)) + # def foreach(self, func): + # return self.foreachRDD(lambda _, rdd: rdd.foreach(func)) def foreachRDD(self, func): """ - Apply userdefined function to all RDD in a DStream. - This python implementation could be expensive because it uses callback server - in order to apply function to RDD in DStream. - This is an output operator, so this DStream will be registered as an output - stream and there materialized. + Apply a function to each RDD in this DStream. 
""" jfunc = RDDFunction(self.ctx, func, self._jrdd_deserializer) api = self._ssc._jvm.PythonDStream @@ -148,13 +163,12 @@ def foreachRDD(self, func): def pprint(self): """ - Print the first ten elements of each RDD generated in this DStream. This is an output - operator, so this DStream will be registered as an output stream and there materialized. + Print the first ten elements of each RDD generated in this DStream. """ - def takeAndPrint(timestamp, rdd): + def takeAndPrint(time, rdd): taken = rdd.take(11) print "-------------------------------------------" - print "Time: %s" % datetime.fromtimestamp(timestamp / 1000.0) + print "Time: %s" % time print "-------------------------------------------" for record in taken[:10]: print record @@ -164,58 +178,18 @@ def takeAndPrint(timestamp, rdd): self.foreachRDD(takeAndPrint) - def _first(self): - """ - Return the first RDD in the stream. - """ - return self._take(1)[0] - - def _take(self, n): - """ - Return the first `n` RDDs in the stream (will start and stop). - """ - results = [] - - def take(_, rdd): - if rdd and len(results) < n: - results.extend(rdd.take(n - len(results))) - - self.foreachRDD(take) - - self._ssc.start() - while len(results) < n: - time.sleep(0.01) - self._ssc.stop(False, True) - return results - - def _collect(self): - """ - Collect each RDDs into the returned list. - - :return: list, which will have the collected items. - """ - result = [] - - def get_output(_, rdd): - r = rdd.collect() - result.append(r) - self.foreachRDD(get_output) - return result - def mapValues(self, f): """ - Pass each value in the key-value pair RDD through a map function - without changing the keys; this also retains the original RDD's - partitioning. + Return a new DStream by applying a map function to the value of + each key-value pairs in 'this' DStream without changing the key. """ map_values_fn = lambda (k, v): (k, f(v)) return self.map(map_values_fn, preservesPartitioning=True) def flatMapValues(self, f): """ - Pass each value in the key-value pair RDD through a flatMap function - without changing the keys; this also retains the original RDD's - partitioning. + Return a new DStream by applying a flatmap function to the value + of each key-value pairs in 'this' DStream without changing the key. """ flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) return self.flatMap(flat_map_fn, preservesPartitioning=True) @@ -223,8 +197,7 @@ def flatMapValues(self, f): def glom(self): """ Return a new DStream in which RDD is generated by applying glom() - to RDD of this DStream. Applying glom() to an RDD coalesces all - elements within each partition into an list. + to RDD of this DStream. """ def func(iterator): yield list(iterator) @@ -232,7 +205,8 @@ def func(iterator): def cache(self): """ - Persist this DStream with the default storage level (C{MEMORY_ONLY_SER}). + Persist the RDDs of this DStream with the default storage level + (C{MEMORY_ONLY_SER}). """ self.is_cached = True self.persist(StorageLevel.MEMORY_ONLY_SER) @@ -240,9 +214,7 @@ def cache(self): def persist(self, storageLevel): """ - Set this DStream's storage level to persist its values across operations - after the first time it is computed. This can only be used to assign - a new storage level if the DStream does not have a storage level set yet. 
+ Persist the RDDs of this DStream with the given storage level """ self.is_cached = True javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel) @@ -251,11 +223,10 @@ def persist(self, storageLevel): def checkpoint(self, interval): """ - Mark this DStream for checkpointing. It will be saved to a file inside the - checkpoint directory set with L{SparkContext.setCheckpointDir()} + Enable periodic checkpointing of RDDs of this DStream - @param interval: time in seconds, after which generated RDD will - be checkpointed + @param interval: time in seconds, after each period of that, generated + RDD will be checkpointed """ self.is_checkpointed = True self._jdstream.checkpoint(self._ssc._jduration(interval)) @@ -263,85 +234,76 @@ def checkpoint(self, interval): def groupByKey(self, numPartitions=None): """ - Return a new DStream which contains group the values for each key in the - DStream into a single sequence. - Hash-partitions the resulting RDD with into numPartitions partitions in - the DStream. - - Note: If you are grouping in order to perform an aggregation (such as a - sum or average) over each key, using reduceByKey will provide much - better performance. + Return a new DStream by applying groupByKey on each RDD. """ + if numPartitions is None: + numPartitions = self.ctx.defaultParallelism return self.transform(lambda rdd: rdd.groupByKey(numPartitions)) def countByValue(self): """ - Return new DStream which contains the count of each unique value in this - DStreeam as a (value, count) pairs. + Return a new DStream in which each RDD contains the counts of each + distinct value in each RDD of this DStream. """ return self.map(lambda x: (x, None)).reduceByKey(lambda x, y: None).count() def saveAsTextFiles(self, prefix, suffix=None): """ - Save this DStream as a text file, using string representations of elements. + Save each RDD in this DStream as at text file, using string + representation of elements. """ - def saveAsTextFile(time, rdd): - """ - Closure to save element in RDD in DStream as Pickled data in file. - This closure is called by py4j callback server. - """ path = rddToFileName(prefix, suffix, time) rdd.saveAsTextFile(path) - return self.foreachRDD(saveAsTextFile) - def saveAsPickleFiles(self, prefix, suffix=None): + def _saveAsPickleFiles(self, prefix, suffix=None): """ - Save this DStream as a SequenceFile of serialized objects. The serializer - used is L{pyspark.serializers.PickleSerializer}, default batch size - is 10. + Save each RDD in this DStream as at binary file, the elements are + serialized by pickle. """ - def saveAsPickleFile(time, rdd): - """ - Closure to save element in RDD in the DStream as Pickled data in file. - This closure is called by py4j callback server. - """ path = rddToFileName(prefix, suffix, time) rdd.saveAsPickleFile(path) - return self.foreachRDD(saveAsPickleFile) def transform(self, func): """ Return a new DStream in which each RDD is generated by applying a function on each RDD of 'this' DStream. - """ - return TransformedDStream(self, lambda t, a: func(a), True) - def transformWithTime(self, func): + `func` can have one argument of `rdd`, or have two arguments of + (`time`, `rdd`) """ - Return a new DStream in which each RDD is generated by applying a function - on each RDD of 'this' DStream. 
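A hedged sketch of the two call shapes transform() and transformWith() accept after this change; `d1` and `d2` are assumed DStreams.

# one-argument form: the function sees only the RDD of each batch
evens = d1.transform(lambda rdd: rdd.filter(lambda x: x % 2 == 0))

# two-stream form: combine the per-batch RDDs of two DStreams
common = d1.transformWith(lambda a, b: a.intersection(b), d2)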
- """ - return TransformedDStream(self, func, False) + resue = False + if func.func_code.co_argcount == 1: + reuse = True + oldfunc = func + func = lambda t, rdd: oldfunc(rdd) + assert func.func_code.co_argcount == 2, "func should take one or two arguments" + return TransformedDStream(self, func, reuse) def transformWith(self, func, other, keepSerializer=False): """ Return a new DStream in which each RDD is generated by applying a function on each RDD of 'this' DStream and 'other' DStream. + + `func` can have two arguments of (`rdd_a`, `rdd_b`) or have three + arguments of (`time`, `rdd_a`, `rdd_b`) """ - jfunc = RDDFunction(self.ctx, lambda t, a, b: func(a, b), self._jrdd_deserializer) + if func.func_code.co_argcount == 2: + oldfunc = func + func = lambda t, a, b: oldfunc(a, b) + assert func.func_code.co_argcount == 3, "func should take two or three arguments" + jfunc = RDDFunction(self.ctx, func, self._jrdd_deserializer) dstream = self.ctx._jvm.PythonTransformed2DStream(self._jdstream.dstream(), other._jdstream.dstream(), jfunc) jrdd_serializer = self._jrdd_deserializer if keepSerializer else self.ctx.serializer return DStream(dstream.asJavaDStream(), self._ssc, jrdd_serializer) - def repartitions(self, numPartitions): + def repartition(self, numPartitions): """ - Return a new DStream with an increased or decreased level of parallelism. Each RDD in the - returned DStream has exactly numPartitions partitions. + Return a new DStream with an increased or decreased level of parallelism. """ return self.transform(lambda rdd: rdd.repartition(numPartitions)) @@ -355,7 +317,8 @@ def _slideDuration(self): def union(self, other): """ Return a new DStream by unifying data of another DStream with this DStream. - @param other Another DStream having the same interval (i.e., slideDuration) as this DStream. + @param other Another DStream having the same interval (i.e., slideDuration) + as this DStream. """ if self._slideDuration != other._slideDuration: raise ValueError("the two DStream should have same slide duration") @@ -368,6 +331,8 @@ def cogroup(self, other, numPartitions=None): Hash partitioning is used to generate the RDDs with `numPartitions` partitions. """ + if numPartitions is None: + numPartitions = self.ctx.defaultParallelism return self.transformWith(lambda a, b: a.cogroup(b, numPartitions), other) def join(self, other, numPartitions=None): @@ -378,6 +343,8 @@ def join(self, other, numPartitions=None): Hash partitioning is used to generate the RDDs with `numPartitions` partitions. """ + if numPartitions is None: + numPartitions = self.ctx.defaultParallelism return self.transformWith(lambda a, b: a.join(b, numPartitions), other) def leftOuterJoin(self, other, numPartitions=None): @@ -388,6 +355,8 @@ def leftOuterJoin(self, other, numPartitions=None): Hash partitioning is used to generate the RDDs with `numPartitions` partitions. """ + if numPartitions is None: + numPartitions = self.ctx.defaultParallelism return self.transformWith(lambda a, b: a.leftOuterJoin(b, numPartitions), other) def rightOuterJoin(self, other, numPartitions=None): @@ -398,6 +367,8 @@ def rightOuterJoin(self, other, numPartitions=None): Hash partitioning is used to generate the RDDs with `numPartitions` partitions. 
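A hedged example of the per-batch join operators in this region; `reqs` and `resps` are assumed DStreams of (id, value) pairs.

joined = reqs.join(resps)         # (id, (req, resp)) for ids present in both batches
left = reqs.leftOuterJoin(resps)  # the resp side is None when the id is missing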
""" + if numPartitions is None: + numPartitions = self.ctx.defaultParallelism return self.transformWith(lambda a, b: a.rightOuterJoin(b, numPartitions), other) def fullOuterJoin(self, other, numPartitions=None): @@ -408,6 +379,8 @@ def fullOuterJoin(self, other, numPartitions=None): Hash partitioning is used to generate the RDDs with `numPartitions` partitions. """ + if numPartitions is None: + numPartitions = self.ctx.defaultParallelism return self.transformWith(lambda a, b: a.fullOuterJoin(b, numPartitions), other) def _jtime(self, timestamp): @@ -426,7 +399,7 @@ def slice(self, begin, end): jrdds = self._jdstream.slice(self._jtime(begin), self._jtime(end)) return [RDD(jrdd, self.ctx, self._jrdd_deserializer) for jrdd in jrdds] - def _check_window(self, window, slide): + def _validate_window_param(self, window, slide): duration = self._jdstream.dstream().slideDuration().milliseconds() if int(window * 1000) % duration != 0: raise ValueError("windowDuration must be multiple of the slide duration (%d ms)" @@ -446,7 +419,7 @@ def window(self, windowDuration, slideDuration=None): the new DStream will generate RDDs); must be a multiple of this DStream's batching interval """ - self._check_window(windowDuration, slideDuration) + self._validate_window_param(windowDuration, slideDuration) d = self._ssc._jduration(windowDuration) if slideDuration is None: return DStream(self._jdstream.window(d), self._ssc, self._jrdd_deserializer) @@ -547,23 +520,22 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None only pairs that satisfy the function are retained set this to null if you do not want to filter """ - self._check_window(windowDuration, slideDuration) - reduced = self.reduceByKey(func) + self._validate_window_param(windowDuration, slideDuration) + if numPartitions is None: + numPartitions = self.ctx.defaultParallelism + + reduced = self.reduceByKey(func, numPartitions) def reduceFunc(t, a, b): b = b.reduceByKey(func, numPartitions) - # use the average of number of partitions, or it will keep increasing - partitions = numPartitions or (a.getNumPartitions() + b.getNumPartitions())/2 - r = a.union(b).reduceByKey(func, partitions) if a else b + r = a.union(b).reduceByKey(func, numPartitions) if a else b if filterFunc: r = r.filter(filterFunc) return r def invReduceFunc(t, a, b): b = b.reduceByKey(func, numPartitions) - # use the average of number of partitions, or it will keep increasing - partitions = numPartitions or (a.getNumPartitions() + b.getNumPartitions())/2 - joined = a.leftOuterJoin(b, partitions) + joined = a.leftOuterJoin(b, numPartitions) return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1) jreduceFunc = RDDFunction(self.ctx, reduceFunc, reduced._jrdd_deserializer) @@ -587,13 +559,14 @@ def updateStateByKey(self, updateFunc, numPartitions=None): @param updateFunc State update function ([(k, vs, s)] -> [(k, s)]). If `s` is None, then `k` will be eliminated. 
""" + if numPartitions is None: + numPartitions = self.ctx.defaultParallelism + def reduceFunc(t, a, b): if a is None: g = b.groupByKey(numPartitions).map(lambda (k, vs): (k, list(vs), None)) else: - # use the average of number of partitions, or it will keep increasing - partitions = numPartitions or (a.getNumPartitions() + b.getNumPartitions())/2 - g = a.cogroup(b, partitions) + g = a.cogroup(b, numPartitions) g = g.map(lambda (k, (va, vb)): (k, list(vb), list(va)[0] if len(va) else None)) state = g.mapPartitions(lambda x: updateFunc(x)) return state.filter(lambda (k, v): v is not None) @@ -605,6 +578,13 @@ def reduceFunc(t, a, b): class TransformedDStream(DStream): + """ + TransformedDStream is an DStream generated by an Python function + transforming each RDD of an DStream to another RDDs. + + Multiple continuous transformations of DStream can be combined into + one transformation. + """ def __init__(self, prev, func, reuse=False): ssc = prev._ssc self._ssc = ssc diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 0dc6b3d675397..698978e61ffad 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -29,17 +29,50 @@ class PySparkStreamingTestCase(unittest.TestCase): timeout = 10 # seconds + duration = 1 def setUp(self): class_name = self.__class__.__name__ self.sc = SparkContext(appName=class_name) self.sc.setCheckpointDir("/tmp") # TODO: decrease duration to speed up tests - self.ssc = StreamingContext(self.sc, duration=1) + self.ssc = StreamingContext(self.sc, self.duration) def tearDown(self): self.ssc.stop() + def _take(self, dstream, n): + """ + Return the first `n` elements in the stream (will start and stop). + """ + results = [] + + def take(_, rdd): + if rdd and len(results) < n: + results.extend(rdd.take(n - len(results))) + + dstream.foreachRDD(take) + + self.ssc.start() + while len(results) < n: + time.sleep(0.01) + self.ssc.stop(False, True) + return results + + def _collect(self, dstream): + """ + Collect each RDDs into the returned list. + + :return: list, which will have the collected items. + """ + result = [] + + def get_output(_, rdd): + r = rdd.collect() + result.append(r) + dstream.foreachRDD(get_output) + return result + def _test_func(self, input, func, expected, sort=False, input2=None): """ @param input: dataset for the test. This should be list of lists. 
@@ -59,7 +92,7 @@ def _test_func(self, input, func, expected, sort=False, input2=None): else: stream = func(input_stream) - result = stream._collect() + result = self._collect(stream) self.ssc.start() start_time = time.time() @@ -89,16 +122,6 @@ def _sort_result_based_on_key(self, outputs): class TestBasicOperations(PySparkStreamingTestCase): - def test_take(self): - input = [range(i) for i in range(3)] - dstream = self.ssc.queueStream(input) - self.assertEqual([0, 0, 1], dstream._take(3)) - - def test_first(self): - input = [range(10)] - dstream = self.ssc.queueStream(input) - self.assertEqual(0, dstream._first()) - def test_map(self): """Basic operation test for DStream.map.""" input = [range(1, 5), range(5, 9), range(9, 13)] @@ -248,7 +271,7 @@ def test_repartition(self): rdds = [self.sc.parallelize(r, 2) for r in input] def func(dstream): - return dstream.repartitions(1).glom() + return dstream.repartition(1).glom() expected = [[[1, 2, 3, 4]], [[5, 6, 7, 8]]] self._test_func(rdds, func, expected) @@ -395,15 +418,9 @@ def func(dstream): self._test_func(input, func, expected) -class TestStreamingContext(unittest.TestCase): - def setUp(self): - self.sc = SparkContext(master="local[2]", appName=self.__class__.__name__) - self.batachDuration = 0.1 - self.ssc = StreamingContext(self.sc, self.batachDuration) +class TestStreamingContext(PySparkStreamingTestCase): - def tearDown(self): - self.ssc.stop() - self.sc.stop() + duration = 0.1 def test_stop_only_streaming_context(self): self._addInputStream() @@ -421,12 +438,12 @@ def _addInputStream(self): # Make sure each length of input is over 3 inputs = map(lambda x: range(1, x), range(5, 101)) stream = self.ssc.queueStream(inputs) - stream._collect() + self._collect(stream) def test_queueStream(self): input = [range(i) for i in range(3)] dstream = self.ssc.queueStream(input) - result = dstream._collect() + result = self._collect(dstream) self.ssc.start() time.sleep(1) self.assertEqual(input, result[:3]) @@ -445,7 +462,7 @@ def test_queueStream(self): # # self.ssc = StreamingContext(self.sc, self.batachDuration) # dstream2 = self.ssc.textFileStream(d) - # result = dstream2._collect() + # result = self._collect(dstream2) # self.ssc.start() # time.sleep(2) # self.assertEqual(input, result[:3]) @@ -455,7 +472,7 @@ def test_union(self): dstream = self.ssc.queueStream(input) dstream2 = self.ssc.queueStream(input) dstream3 = self.ssc.union(dstream, dstream2) - result = dstream3._collect() + result = self._collect(dstream3) self.ssc.start() time.sleep(1) expected = [i * 2 for i in input] @@ -472,7 +489,7 @@ def func(rdds): dstream = self.ssc.transform([dstream1, dstream2, dstream3], func) - self.assertEqual([2, 3, 1], dstream._take(3)) + self.assertEqual([2, 3, 1], self._take(dstream, 3)) if __name__ == "__main__": diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index 57791805e8f9f..4838ec6c8c6e9 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -15,6 +15,8 @@ # limitations under the License. 
# +from datetime import datetime + from pyspark.rdd import RDD @@ -40,7 +42,8 @@ def call(self, milliseconds, jrdds): rdds = [RDD(jrdd, self.ctx, ser) if jrdd else self.emptyRDD for jrdd, ser in zip(jrdds, sers)] - r = self.func(milliseconds, *rdds) + t = datetime.fromtimestamp(milliseconds / 1000.0) + r = self.func(t, *rdds) if r: return r._jrdd except Exception: diff --git a/python/run-tests b/python/run-tests index e8796838c22c1..e86e0729cf65e 100755 --- a/python/run-tests +++ b/python/run-tests @@ -93,9 +93,9 @@ fi echo "Testing with Python version:" $PYSPARK_PYTHON --version -#run_core_tests -#run_sql_tests -#run_mllib_tests +run_core_tests +run_sql_tests +run_mllib_tests run_streaming_tests # Try to test with PyPy diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 30c52c15e9e68..658715eb456dd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -34,7 +34,8 @@ import org.apache.spark.streaming.api.java._ /** * Interface for Python callback function with three arguments */ -trait PythonRDDFunction { +private[spark] trait PythonRDDFunction { + // callback in Python def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] } @@ -44,38 +45,30 @@ trait PythonRDDFunction { private[python] class RDDFunction(pfunc: PythonRDDFunction) extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] with Serializable { - def wrapRDD(rdd: Option[RDD[_]]): JavaRDD[_] = { - if (rdd.isDefined) { - JavaRDD.fromRDD(rdd.get) - } else { - null - } - } - - def some(jrdd: JavaRDD[Array[Byte]]): Option[RDD[Array[Byte]]] = { - if (jrdd != null) { - Some(jrdd.rdd) - } else { - None - } - } - def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { - some(pfunc.call(time.milliseconds, List(wrapRDD(rdd)).asJava)) + PythonDStream.some(pfunc.call(time.milliseconds, List(PythonDStream.wrapRDD(rdd)).asJava)) } def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { - some(pfunc.call(time.milliseconds, List(wrapRDD(rdd), wrapRDD(rdd2)).asJava)) + val rdds = List(PythonDStream.wrapRDD(rdd), PythonDStream.wrapRDD(rdd2)).asJava + PythonDStream.some(pfunc.call(time.milliseconds, rdds)) } - // for JFunction2 + // for function.Function2 def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = { pfunc.call(time.milliseconds, rdds) } } + +/** + * Base class for PythonDStream with some common methods + */ private[python] -abstract class PythonDStream(parent: DStream[_]) extends DStream[Array[Byte]] (parent.ssc) { +abstract class PythonDStream(parent: DStream[_], pfunc: PythonRDDFunction) + extends DStream[Array[Byte]] (parent.ssc) { + + val func = new RDDFunction(pfunc) override def dependencies = List(parent) @@ -84,12 +77,33 @@ abstract class PythonDStream(parent: DStream[_]) extends DStream[Array[Byte]] (p val asJavaDStream = JavaDStream.fromDStream(this) } +/** + * Helper functions + */ private[spark] object PythonDStream { + // convert Option[RDD[_]] to JavaRDD, handle null gracefully + def wrapRDD(rdd: Option[RDD[_]]): JavaRDD[_] = { + if (rdd.isDefined) { + JavaRDD.fromRDD(rdd.get) + } else { + null + } + } + + // convert JavaRDD to Option[RDD[Array[Byte]]] to , handle null gracefully + def some(jrdd: JavaRDD[Array[Byte]]): Option[RDD[Array[Byte]]] = { + if (jrdd != null) { 
+ Some(jrdd.rdd) + } else { + None + } + } + // helper function for DStream.foreachRDD(), // cannot be `foreachRDD`, it will confusing py4j - def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pyfunc: PythonRDDFunction){ - val func = new RDDFunction(pyfunc) + def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonRDDFunction){ + val func = new RDDFunction((pfunc)) jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time)) } @@ -112,34 +126,36 @@ private[spark] object PythonDStream { /** * Transformed DStream in Python. * - * If the result RDD is PythonRDD, then it will cache it as an template for future use, - * this can reduce the Python callbacks. + * If `reuse` is true and the result of the `func` is an PythonRDD, then it will cache it + * as an template for future use, this can reduce the Python callbacks. */ private[spark] class PythonTransformedDStream (parent: DStream[_], pfunc: PythonRDDFunction, var reuse: Boolean = false) - extends PythonDStream(parent) { + extends PythonDStream(parent, pfunc) { - val func = new RDDFunction(pfunc) + // rdd returned by func var lastResult: PythonRDD = _ override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { - val rdd1 = parent.getOrCompute(validTime) - if (rdd1.isEmpty) { + val rdd = parent.getOrCompute(validTime) + if (rdd.isEmpty) { return None } if (reuse && lastResult != null) { - Some(lastResult.copyTo(rdd1.get)) + // use the previous result as the template to generate new RDD + Some(lastResult.copyTo(rdd.get)) } else { - val r = func(rdd1, validTime) + val r = func(rdd, validTime) if (reuse && r.isDefined && lastResult == null) { + // try to use the result as a template r.get match { - case rdd: PythonRDD => - if (rdd.parent(0) == rdd1) { + case pyrdd: PythonRDD => + if (pyrdd.parent(0) == rdd) { // only one PythonRDD - lastResult = rdd + lastResult = pyrdd } else { - // may have multiple stages + // maybe have multiple stages, don't check it anymore reuse = false } } @@ -174,10 +190,8 @@ class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_], * similar to StateDStream */ private[spark] -class PythonStateDStream(parent: DStream[Array[Byte]], preduceFunc: PythonRDDFunction) - extends PythonDStream(parent) { - - val reduceFunc = new RDDFunction(preduceFunc) +class PythonStateDStream(parent: DStream[Array[Byte]], reduceFunc: PythonRDDFunction) + extends PythonDStream(parent, reduceFunc) { super.persist(StorageLevel.MEMORY_ONLY) override val mustCheckpoint = true @@ -186,7 +200,7 @@ class PythonStateDStream(parent: DStream[Array[Byte]], preduceFunc: PythonRDDFun val lastState = getOrCompute(validTime - slideDuration) val rdd = parent.getOrCompute(validTime) if (rdd.isDefined) { - reduceFunc(lastState, rdd, validTime) + func(lastState, rdd, validTime) } else { lastState } @@ -244,7 +258,7 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]], // add the RDDs of the reduced values in "new time steps" val newRDDs = parent.slice(previous.endTime + parent.slideDuration, current.endTime) if (newRDDs.size > 0) { - reduceFunc(subtracted, Some(ssc.sc.union(newRDDs)), validTime) + func(subtracted, Some(ssc.sc.union(newRDDs)), validTime) } else { subtracted } @@ -252,7 +266,7 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]], // Get the RDDs of the reduced values in current window val currentRDDs = parent.slice(current.beginTime + parent.slideDuration, current.endTime) if (currentRDDs.size > 0) { - reduceFunc(None, Some(ssc.sc.union(currentRDDs)), validTime) + func(None, 
Some(ssc.sc.union(currentRDDs)), validTime) } else { None } From eed6e2a034646d91ddddccb42aee6809e0faa93e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 30 Sep 2014 00:48:29 -0700 Subject: [PATCH 323/347] rollback not needed changes --- bin/pyspark | 6 ++- python/pyspark/accumulators.py | 5 -- python/pyspark/serializers.py | 5 -- python/pyspark/streaming/tests.py | 38 +++++++-------- python/run-tests | 81 +++++++++++++++---------------- 5 files changed, 64 insertions(+), 71 deletions(-) diff --git a/bin/pyspark b/bin/pyspark index 118e6851af7a0..5142411e36974 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -87,7 +87,11 @@ export PYSPARK_SUBMIT_ARGS if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR - exec "$PYSPARK_PYTHON" $1 + if [[ -n "$PYSPARK_DOC_TEST" ]]; then + exec "$PYSPARK_PYTHON" -m doctest $1 + else + exec "$PYSPARK_PYTHON" $1 + fi exit fi diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 9aa3db7ccf1dd..ccbca67656c8d 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -256,8 +256,3 @@ def _start_update_server(): thread.daemon = True thread.start() return server - - -if __name__ == "__main__": - import doctest - doctest.testmod() diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index e666dd9800256..94bebc310bad6 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -526,8 +526,3 @@ def write_int(value, stream): def write_with_length(obj, stream): write_int(len(obj), stream) stream.write(obj) - - -if __name__ == "__main__": - import doctest - doctest.testmod() diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 698978e61ffad..09d2670cc1962 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -341,6 +341,25 @@ def func(a, b): expected = [[('a', (1, None)), ('b', (2, 3)), ('c', (None, 4))]] self._test_func(input, func, expected, True, input2) + def update_state_by_key(self): + + def updater(it): + for k, vs, s in it: + if not s: + s = vs + else: + s.extend(vs) + yield (k, s) + + input = [[('k', i)] for i in range(5)] + + def func(dstream): + return dstream.updateStateByKey(updater) + + expected = [[0], [0, 1], [0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]] + expected = [[('k', v)] for v in expected] + self._test_func(input, func, expected) + class TestWindowFunctions(PySparkStreamingTestCase): @@ -398,25 +417,6 @@ def test_reduce_by_invalid_window(self): self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1)) self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1)) - def update_state_by_key(self): - - def updater(it): - for k, vs, s in it: - if not s: - s = vs - else: - s.extend(vs) - yield (k, s) - - input = [[('k', i)] for i in range(5)] - - def func(dstream): - return dstream.updateStateByKey(updater) - - expected = [[0], [0, 1], [0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]] - expected = [[('k', v)] for v in expected] - self._test_func(input, func, expected) - class TestStreamingContext(PySparkStreamingTestCase): diff --git a/python/run-tests b/python/run-tests index e86e0729cf65e..c5cb580f77fd2 100755 --- a/python/run-tests +++ b/python/run-tests @@ -48,39 +48,6 @@ function run_test() { fi } -function run_core_tests() { - run_test "pyspark/conf.py" - run_test "pyspark/context.py" - run_test "pyspark/broadcast.py" - run_test "pyspark/accumulators.py" - run_test "pyspark/serializers.py" - run_test 
"pyspark/shuffle.py" - run_test "pyspark/rdd.py" - run_test "pyspark/tests.py" -} - -function run_sql_tests() { - run_test "pyspark/sql.py" -} - -function run_mllib_tests() { - run_test "pyspark/mllib/util.py" - run_test "pyspark/mllib/linalg.py" - run_test "pyspark/mllib/classification.py" - run_test "pyspark/mllib/clustering.py" - run_test "pyspark/mllib/random.py" - run_test "pyspark/mllib/recommendation.py" - run_test "pyspark/mllib/regression.py" - run_test "pyspark/mllib/stat.py" - run_test "pyspark/mllib/tree.py" - run_test "pyspark/mllib/tests.py" -} - -function run_streaming_tests() { - run_test "pyspark/streaming/util.py" - run_test "pyspark/streaming/tests.py" -} - echo "Running PySpark tests. Output is in python/unit-tests.log." export PYSPARK_PYTHON="python" @@ -93,10 +60,31 @@ fi echo "Testing with Python version:" $PYSPARK_PYTHON --version -run_core_tests -run_sql_tests -run_mllib_tests -run_streaming_tests +run_test "pyspark/rdd.py" +run_test "pyspark/context.py" +run_test "pyspark/conf.py" +run_test "pyspark/sql.py" +# These tests are included in the module-level docs, and so must +# be handled on a higher level rather than within the python file. +export PYSPARK_DOC_TEST=1 +run_test "pyspark/broadcast.py" +run_test "pyspark/accumulators.py" +run_test "pyspark/serializers.py" +unset PYSPARK_DOC_TEST +run_test "pyspark/shuffle.py" +run_test "pyspark/tests.py" +run_test "pyspark/mllib/classification.py" +run_test "pyspark/mllib/clustering.py" +run_test "pyspark/mllib/linalg.py" +run_test "pyspark/mllib/random.py" +run_test "pyspark/mllib/recommendation.py" +run_test "pyspark/mllib/regression.py" +run_test "pyspark/mllib/stat.py" +run_test "pyspark/mllib/tests.py" +run_test "pyspark/mllib/tree.py" +run_test "pyspark/mllib/util.py" +run_test "pyspark/streaming/util.py" +run_test "pyspark/streaming/tests.py" # Try to test with PyPy if [ $(which pypy) ]; then @@ -104,10 +92,21 @@ if [ $(which pypy) ]; then echo "Testing with PyPy version:" $PYSPARK_PYTHON --version - run_core_tests - run_sql_tests - run_mllib_tests - run_streaming_tests + run_test "pyspark/rdd.py" + run_test "pyspark/context.py" + run_test "pyspark/conf.py" + run_test "pyspark/sql.py" + # These tests are included in the module-level docs, and so must + # be handled on a higher level rather than within the python file. 
+ export PYSPARK_DOC_TEST=1 + run_test "pyspark/broadcast.py" + run_test "pyspark/accumulators.py" + run_test "pyspark/serializers.py" + unset PYSPARK_DOC_TEST + run_test "pyspark/shuffle.py" + run_test "pyspark/tests.py" + run_test "pyspark/streaming/util.py" + run_test "pyspark/streaming/tests.py" fi if [[ $FAILED == 0 ]]; then From b98d63fbde10f20a42e1e6e0f34f45736b802772 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 30 Sep 2014 00:52:47 -0700 Subject: [PATCH 324/347] change private[spark] to private[python] --- .../spark/streaming/api/python/PythonDStream.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 658715eb456dd..4a52ce1c4f43a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -34,7 +34,7 @@ import org.apache.spark.streaming.api.java._ /** * Interface for Python callback function with three arguments */ -private[spark] trait PythonRDDFunction { +private[python] trait PythonRDDFunction { // callback in Python def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] } @@ -80,7 +80,7 @@ abstract class PythonDStream(parent: DStream[_], pfunc: PythonRDDFunction) /** * Helper functions */ -private[spark] object PythonDStream { +private[python] object PythonDStream { // convert Option[RDD[_]] to JavaRDD, handle null gracefully def wrapRDD(rdd: Option[RDD[_]]): JavaRDD[_] = { @@ -129,7 +129,7 @@ private[spark] object PythonDStream { * If `reuse` is true and the result of the `func` is an PythonRDD, then it will cache it * as an template for future use, this can reduce the Python callbacks. */ -private[spark] +private[python] class PythonTransformedDStream (parent: DStream[_], pfunc: PythonRDDFunction, var reuse: Boolean = false) extends PythonDStream(parent, pfunc) { @@ -168,7 +168,7 @@ class PythonTransformedDStream (parent: DStream[_], pfunc: PythonRDDFunction, /** * Transformed from two DStreams in Python. 
*/ -private[spark] +private[python] class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_], pfunc: PythonRDDFunction) extends DStream[Array[Byte]] (parent.ssc) { @@ -189,7 +189,7 @@ class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_], /** * similar to StateDStream */ -private[spark] +private[python] class PythonStateDStream(parent: DStream[Array[Byte]], reduceFunc: PythonRDDFunction) extends PythonDStream(parent, reduceFunc) { @@ -210,7 +210,7 @@ class PythonStateDStream(parent: DStream[Array[Byte]], reduceFunc: PythonRDDFunc /** * similar to ReducedWindowedDStream */ -private[spark] +private[python] class PythonReducedWindowedDStream(parent: DStream[Array[Byte]], preduceFunc: PythonRDDFunction, pinvReduceFunc: PythonRDDFunction, From 9a16bd1bdce5b66ff3701aeb94b77d94e8b0a521 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 30 Sep 2014 10:08:32 -0700 Subject: [PATCH 325/347] change number of partitions during tests --- python/pyspark/streaming/tests.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 09d2670cc1962..bd6d92255dbc6 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -22,7 +22,7 @@ import unittest import tempfile -from pyspark.context import SparkContext, RDD +from pyspark.context import SparkConf, SparkContext, RDD from pyspark.streaming.context import StreamingContext @@ -33,7 +33,8 @@ class PySparkStreamingTestCase(unittest.TestCase): def setUp(self): class_name = self.__class__.__name__ - self.sc = SparkContext(appName=class_name) + conf = SparkConf().set("spark.default.parallelism", 1) + self.sc = SparkContext(appName=class_name, conf=conf) self.sc.setCheckpointDir("/tmp") # TODO: decrease duration to speed up tests self.ssc = StreamingContext(self.sc, self.duration) From 8466916cec3ce6ebba8c3c2c35f7ad4c74f90e66 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 30 Sep 2014 11:51:54 -0700 Subject: [PATCH 326/347] support checkpoint --- python/pyspark/streaming/context.py | 7 +- python/pyspark/streaming/util.py | 28 +++++- .../spark/streaming/StreamingContext.scala | 2 +- .../streaming/api/python/PythonDStream.scala | 87 +++++++++++++++---- 4 files changed, 101 insertions(+), 23 deletions(-) diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index ae4a1d5b6b069..da645a6201503 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -19,11 +19,11 @@ from py4j.java_gateway import java_import from pyspark import RDD -from pyspark.serializers import UTF8Deserializer +from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer from pyspark.context import SparkContext from pyspark.storagelevel import StorageLevel from pyspark.streaming.dstream import DStream -from pyspark.streaming.util import RDDFunction +from pyspark.streaming.util import RDDFunction, RDDFunctionSerializer __all__ = ["StreamingContext"] @@ -100,6 +100,9 @@ def _initialize_context(self, sc, duration): java_import(self._jvm, "org.apache.spark.streaming.*") java_import(self._jvm, "org.apache.spark.streaming.api.java.*") java_import(self._jvm, "org.apache.spark.streaming.api.python.*") + # register serializer for RDDFunction + ser = RDDFunctionSerializer(self._sc, CloudPickleSerializer()) + self._jvm.PythonDStream.registerSerializer(ser) return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration)) def _jduration(self, seconds): diff --git 
a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index 4838ec6c8c6e9..c15f9d98c1866 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -16,6 +16,7 @@ # from datetime import datetime +import traceback from pyspark.rdd import RDD @@ -47,7 +48,6 @@ def call(self, milliseconds, jrdds): if r: return r._jrdd except Exception: - import traceback traceback.print_exc() def __repr__(self): @@ -57,6 +57,32 @@ class Java: implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction'] +class RDDFunctionSerializer(object): + def __init__(self, ctx, serializer): + self.ctx = ctx + self.serializer = serializer + + def dumps(self, id): + try: + func = self.ctx._gateway.gateway_property.pool[id] + return bytearray(self.serializer.dumps((func.func, func.deserializers))) + except Exception: + traceback.print_exc() + + def loads(self, bytes): + try: + f, deserializers = self.serializer.loads(str(bytes)) + return RDDFunction(self.ctx, f, *deserializers) + except Exception: + traceback.print_exc() + + def __repr__(self): + return "RDDFunctionSerializer(%s)" % self.serializer + + class Java: + implements = ['org.apache.spark.streaming.api.python.PythonRDDFunctionSerializer'] + + def rddToFileName(prefix, suffix, time): """ Return string prefix-time(.suffix) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index ef7631788f26d..5a8eef1372e23 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -413,7 +413,7 @@ class StreamingContext private[streaming] ( dstreams: Seq[DStream[_]], transformFunc: (Seq[RDD[_]], Time) => RDD[T] ): DStream[T] = { - new TransformedDStream[T](dstreams, transformFunc) + new TransformedDStream[T](dstreams, sparkContext.clean(transformFunc)) } /** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 4a52ce1c4f43a..ddbbf107abb3e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -17,10 +17,11 @@ package org.apache.spark.streaming.api.python +import java.io.{ObjectInputStream, ObjectOutputStream} +import java.lang.reflect.Proxy import java.util.{ArrayList => JArrayList, List => JList} import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ -import scala.collection.mutable import org.apache.spark.api.java._ import org.apache.spark.api.python._ @@ -35,14 +36,14 @@ import org.apache.spark.streaming.api.java._ * Interface for Python callback function with three arguments */ private[python] trait PythonRDDFunction { - // callback in Python def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] } /** * Wrapper for PythonRDDFunction + * TODO: support checkpoint */ -private[python] class RDDFunction(pfunc: PythonRDDFunction) +private[python] class RDDFunction(@transient var pfunc: PythonRDDFunction) extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] with Serializable { def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { @@ -58,23 +59,47 @@ private[python] class RDDFunction(pfunc: 
PythonRDDFunction) def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = { pfunc.call(time.milliseconds, rdds) } -} + private def writeObject(out: ObjectOutputStream): Unit = { + assert(PythonDStream.serializer != null, "Serializer has not been registered!") + val bytes = PythonDStream.serializer.serialize(pfunc) + out.writeInt(bytes.length) + out.write(bytes) + } + + private def readObject(in: ObjectInputStream): Unit = { + assert(PythonDStream.serializer != null, "Serializer has not been registered!") + val length = in.readInt() + val bytes = new Array[Byte](length) + in.readFully(bytes) + pfunc = PythonDStream.serializer.deserialize(bytes) + } +} /** - * Base class for PythonDStream with some common methods + * Inferface for Python Serializer to serialize PythonRDDFunction */ -private[python] -abstract class PythonDStream(parent: DStream[_], pfunc: PythonRDDFunction) - extends DStream[Array[Byte]] (parent.ssc) { - - val func = new RDDFunction(pfunc) - - override def dependencies = List(parent) +private[python] trait PythonRDDFunctionSerializer { + def dumps(id: String): Array[Byte] // + def loads(bytes: Array[Byte]): PythonRDDFunction +} - override def slideDuration: Duration = parent.slideDuration +/** + * Wrapper for PythonRDDFunctionSerializer + */ +private[python] class RDDFunctionSerializer(pser: PythonRDDFunctionSerializer) { + def serialize(func: PythonRDDFunction): Array[Byte] = { + // get the id of PythonRDDFunction in py4j + val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy]) + val f = h.getClass().getDeclaredField("id"); + f.setAccessible(true); + val id = f.get(h).asInstanceOf[String]; + pser.dumps(id) + } - val asJavaDStream = JavaDStream.fromDStream(this) + def deserialize(bytes: Array[Byte]): PythonRDDFunction = { + pser.loads(bytes) + } } /** @@ -82,6 +107,14 @@ abstract class PythonDStream(parent: DStream[_], pfunc: PythonRDDFunction) */ private[python] object PythonDStream { + // A serializer in Python, used to serialize PythonRDDFunction + var serializer: RDDFunctionSerializer = _ + + // Register a serializer from Python, should be called during initialization + def registerSerializer(ser: PythonRDDFunctionSerializer) = { + serializer = new RDDFunctionSerializer(ser) + } + // convert Option[RDD[_]] to JavaRDD, handle null gracefully def wrapRDD(rdd: Option[RDD[_]]): JavaRDD[_] = { if (rdd.isDefined) { @@ -123,6 +156,22 @@ private[python] object PythonDStream { } } +/** + * Base class for PythonDStream with some common methods + */ +private[python] +abstract class PythonDStream(parent: DStream[_], @transient pfunc: PythonRDDFunction) + extends DStream[Array[Byte]] (parent.ssc) { + + val func = new RDDFunction(pfunc) + + override def dependencies = List(parent) + + override def slideDuration: Duration = parent.slideDuration + + val asJavaDStream = JavaDStream.fromDStream(this) +} + /** * Transformed DStream in Python. * @@ -130,7 +179,7 @@ private[python] object PythonDStream { * as an template for future use, this can reduce the Python callbacks. 
*/ private[python] -class PythonTransformedDStream (parent: DStream[_], pfunc: PythonRDDFunction, +class PythonTransformedDStream (parent: DStream[_], @transient pfunc: PythonRDDFunction, var reuse: Boolean = false) extends PythonDStream(parent, pfunc) { @@ -170,7 +219,7 @@ class PythonTransformedDStream (parent: DStream[_], pfunc: PythonRDDFunction, */ private[python] class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_], - pfunc: PythonRDDFunction) + @transient pfunc: PythonRDDFunction) extends DStream[Array[Byte]] (parent.ssc) { val func = new RDDFunction(pfunc) @@ -190,7 +239,7 @@ class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_], * similar to StateDStream */ private[python] -class PythonStateDStream(parent: DStream[Array[Byte]], reduceFunc: PythonRDDFunction) +class PythonStateDStream(parent: DStream[Array[Byte]], @transient reduceFunc: PythonRDDFunction) extends PythonDStream(parent, reduceFunc) { super.persist(StorageLevel.MEMORY_ONLY) @@ -212,8 +261,8 @@ class PythonStateDStream(parent: DStream[Array[Byte]], reduceFunc: PythonRDDFunc */ private[python] class PythonReducedWindowedDStream(parent: DStream[Array[Byte]], - preduceFunc: PythonRDDFunction, - pinvReduceFunc: PythonRDDFunction, + @transient preduceFunc: PythonRDDFunction, + @transient pinvReduceFunc: PythonRDDFunction, _windowDuration: Duration, _slideDuration: Duration ) extends PythonStateDStream(parent, preduceFunc) { From a13ff34d76c35f1a28bb09b8787715c767c9f515 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 30 Sep 2014 12:25:14 -0700 Subject: [PATCH 327/347] address comments --- .../main/python/streaming/hdfs_wordcount.py | 32 +++++++++++++++-- .../python/streaming/network_wordcount.py | 30 +++++++++++++++- python/pyspark/streaming/context.py | 10 +++--- python/pyspark/streaming/dstream.py | 3 -- .../streaming/api/python/PythonDStream.scala | 36 +++---------------- 5 files changed, 69 insertions(+), 42 deletions(-) diff --git a/examples/src/main/python/streaming/hdfs_wordcount.py b/examples/src/main/python/streaming/hdfs_wordcount.py index 8c08ff0c89850..40faff0ccc7db 100644 --- a/examples/src/main/python/streaming/hdfs_wordcount.py +++ b/examples/src/main/python/streaming/hdfs_wordcount.py @@ -1,3 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in new text files created in the given directory + Usage: hdfs_wordcount.py + is the directory that Spark Streaming will use to find and read new text files. + + To run this on your local machine on directory `localdir`, run this example + $ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localdir + + Then create a text file in `localdir` and the words in the file will get counted. 
+""" + import sys from pyspark import SparkContext @@ -5,10 +33,10 @@ if __name__ == "__main__": if len(sys.argv) != 2: - print >> sys.stderr, "Usage: wordcount " + print >> sys.stderr, "Usage: hdfs_wordcount.py " exit(-1) - sc = SparkContext(appName="PythonStreamingWordCount") + sc = SparkContext(appName="PythonStreamingHDFSWordCount") ssc = StreamingContext(sc, 1) lines = ssc.textFileStream(sys.argv[1]) diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py index e3b6248c82a12..cfa9c1ff5bfbc 100644 --- a/examples/src/main/python/streaming/network_wordcount.py +++ b/examples/src/main/python/streaming/network_wordcount.py @@ -1,3 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + Usage: network_wordcount.py + and describe the TCP server that Spark Streaming would connect to receive data. + + To run this on your local machine, you need to first run a Netcat server + `$ nc -lk 9999` + and then run the example + `$ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999` +""" + import sys from pyspark import SparkContext @@ -5,7 +33,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: wordcount " + print >> sys.stderr, "Usage: network_wordcount.py " exit(-1) sc = SparkContext(appName="PythonStreamingNetworkWordCount") ssc = StreamingContext(sc, 1) diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index da645a6201503..9808361eb664f 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -234,11 +234,11 @@ def transform(self, dstreams, transformFunc): jdstreams = ListConverter().convert([d._jdstream for d in dstreams], SparkContext._gateway._gateway_client) # change the final serializer to sc.serializer - jfunc = RDDFunction(self._sc, - lambda t, *rdds: transformFunc(rdds).map(lambda x: x), - *[d._jrdd_deserializer for d in dstreams]) - - jdstream = self._jvm.PythonDStream.callTransform(self._jssc, jdstreams, jfunc) + func = RDDFunction(self._sc, + lambda t, *rdds: transformFunc(rdds).map(lambda x: x), + *[d._jrdd_deserializer for d in dstreams]) + jfunc = self._jvm.RDDFunction(func) + jdstream = self._jssc.transform(jdstreams, jfunc) return DStream(jdstream, self, self._sc.serializer) def union(self, *dstreams): diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 4e3f07e26953b..87d5bb4906bd5 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -150,9 +150,6 @@ def partitionBy(self, numPartitions, partitionFunc=portable_hash): """ return self.transform(lambda rdd: 
rdd.partitionBy(numPartitions, partitionFunc)) - # def foreach(self, func): - # return self.foreachRDD(lambda _, rdd: rdd.foreach(func)) - def foreachRDD(self, func): """ Apply a function to each RDD in this DStream. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index ddbbf107abb3e..4a19f27fe9c7d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -47,12 +47,12 @@ private[python] class RDDFunction(@transient var pfunc: PythonRDDFunction) extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] with Serializable { def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { - PythonDStream.some(pfunc.call(time.milliseconds, List(PythonDStream.wrapRDD(rdd)).asJava)) + Option(pfunc.call(time.milliseconds, List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava)).map(_.rdd) } def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { - val rdds = List(PythonDStream.wrapRDD(rdd), PythonDStream.wrapRDD(rdd2)).asJava - PythonDStream.some(pfunc.call(time.milliseconds, rdds)) + val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull, rdd2.map(JavaRDD.fromRDD(_)).orNull).asJava + Option(pfunc.call(time.milliseconds, rdds)).map(_.rdd) } // for function.Function2 @@ -115,39 +115,13 @@ private[python] object PythonDStream { serializer = new RDDFunctionSerializer(ser) } - // convert Option[RDD[_]] to JavaRDD, handle null gracefully - def wrapRDD(rdd: Option[RDD[_]]): JavaRDD[_] = { - if (rdd.isDefined) { - JavaRDD.fromRDD(rdd.get) - } else { - null - } - } - - // convert JavaRDD to Option[RDD[Array[Byte]]] to , handle null gracefully - def some(jrdd: JavaRDD[Array[Byte]]): Option[RDD[Array[Byte]]] = { - if (jrdd != null) { - Some(jrdd.rdd) - } else { - None - } - } - // helper function for DStream.foreachRDD(), // cannot be `foreachRDD`, it will confusing py4j - def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonRDDFunction){ + def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonRDDFunction) { val func = new RDDFunction((pfunc)) jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time)) } - // helper function for ssc.transform() - def callTransform(ssc: JavaStreamingContext, jdsteams: JList[JavaDStream[_]], - pyfunc: PythonRDDFunction) - :JavaDStream[Array[Byte]] = { - val func = new RDDFunction(pyfunc) - ssc.transform(jdsteams, func) - } - // convert list of RDD into queue of RDDs, for ssc.queueStream() def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = { val queue = new java.util.LinkedList[JavaRDD[Array[Byte]]] @@ -232,7 +206,7 @@ class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_], func(parent.getOrCompute(validTime), parent2.getOrCompute(validTime), validTime) } - val asJavaDStream = JavaDStream.fromDStream(this) + val asJavaDStream = JavaDStream.fromDStream(this) } /** From fa7261b5610a02fe725f975fada995d37234f615 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 30 Sep 2014 13:36:00 -0700 Subject: [PATCH 328/347] refactor --- .../apache/spark/streaming/api/python/PythonDStream.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala 
b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 4a19f27fe9c7d..f2ed0c507c2b7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -239,7 +239,10 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]], @transient pinvReduceFunc: PythonRDDFunction, _windowDuration: Duration, _slideDuration: Duration - ) extends PythonStateDStream(parent, preduceFunc) { + ) extends PythonDStream(parent, preduceFunc) { + + super.persist(StorageLevel.MEMORY_ONLY) + override val mustCheckpoint = true val invReduceFunc = new RDDFunction(pinvReduceFunc) From 6f0da2fa486c2a580045a2e9e3133b6617875363 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 1 Oct 2014 00:08:54 -0700 Subject: [PATCH 329/347] recover from checkpoint --- .../apache/spark/api/python/PythonRDD.scala | 8 +- .../spark/rdd/ParallelCollectionRDD.scala | 2 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 8 ++ python/pyspark/context.py | 8 +- python/pyspark/streaming/context.py | 76 ++++++++++++++----- python/pyspark/streaming/tests.py | 33 ++++++++ python/pyspark/streaming/util.py | 24 ++++-- .../streaming/api/python/PythonDStream.scala | 8 +- .../streaming/dstream/QueueInputDStream.scala | 7 ++ 9 files changed, 136 insertions(+), 38 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 8051b221ac3d1..b093917430a59 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -42,7 +42,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils private[spark] class PythonRDD( - parent: RDD[_], + @transient parent: RDD[_], command: Array[Byte], envVars: JMap[String, String], pythonIncludes: JList[String], @@ -61,9 +61,9 @@ private[spark] class PythonRDD( val bufferSize = conf.getInt("spark.buffer.size", 65536) val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true) - override def getPartitions = parent.partitions + override def getPartitions = firstParent.partitions - override val partitioner = if (preservePartitoning) parent.partitioner else None + override val partitioner = if (preservePartitoning) firstParent.partitioner else None override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis @@ -241,7 +241,7 @@ private[spark] class PythonRDD( dataOut.writeInt(command.length) dataOut.write(command) // Data values - PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut) + PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) dataOut.flush() } catch { diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index 66c71bf7e8bb5..1069e23241302 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -84,7 +84,7 @@ private[spark] class ParallelCollectionPartition[T: ClassTag]( private[spark] class ParallelCollectionRDD[T: ClassTag]( @transient sc: SparkContext, - @transient data: Seq[T], + data: Seq[T], numSlices: Int, locationPrefs: Map[Int, Seq[String]]) extends RDD[T](sc, Nil) { diff 
--git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 0e90caa5c9ca7..352ce5e00d5ec 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -82,6 +82,14 @@ abstract class RDD[T: ClassTag]( def this(@transient oneParent: RDD[_]) = this(oneParent.context , List(new OneToOneDependency(oneParent))) + // setContext after loading from checkpointing + private[spark] def setContext(s: SparkContext) = { + if (sc != null && sc != s) { + throw new SparkException("Context is already set in " + this + ", cannot set it again") + } + sc = s + } + private[spark] def conf = sc.conf // ======================================================================= // Methods that should be implemented by subclasses of RDD diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 8e7b00469e246..ba930d949101d 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -68,7 +68,7 @@ class SparkContext(object): def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None, - gateway=None): + gateway=None, jsc=None): """ Create a new SparkContext. At least the master and app name should be set, either through the named parameters here or through C{conf}. @@ -103,14 +103,14 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, SparkContext._ensure_initialized(self, gateway=gateway) try: self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer, - conf) + conf, jsc) except: # If an error occurs, clean up in order to allow future SparkContext creation: self.stop() raise def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer, - conf): + conf, jsc): self.environment = environment or {} self._conf = conf or SparkConf(_jvm=self._jvm) self._batchSize = batchSize # -1 represents an unlimited batch size @@ -151,7 +151,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self.environment[varName] = v # Create the Java SparkContext through Py4J - self._jsc = self._initialize_context(self._conf._jconf) + self._jsc = jsc or self._initialize_context(self._conf._jconf) # Create a single Accumulator in Java that we'll send all our updates through; # they will be passed back to us through a TCP server diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 9808361eb664f..759feda169cff 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -14,11 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import os +import sys from py4j.java_collections import ListConverter from py4j.java_gateway import java_import -from pyspark import RDD +from pyspark import RDD, SparkConf from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer from pyspark.context import SparkContext from pyspark.storagelevel import StorageLevel @@ -75,41 +77,81 @@ class StreamingContext(object): respectively. `context.awaitTransformation()` allows the current thread to wait for the termination of the context by `stop()` or by an exception. """ + _transformerSerializer = None - def __init__(self, sparkContext, duration): + def __init__(self, sparkContext, duration=None, jssc=None): """ Create a new StreamingContext. @param sparkContext: L{SparkContext} object. 
@param duration: number of seconds. """ + self._sc = sparkContext self._jvm = self._sc._jvm - self._start_callback_server() - self._jssc = self._initialize_context(self._sc, duration) + self._jssc = jssc or self._initialize_context(self._sc, duration) + + def _initialize_context(self, sc, duration): + self._ensure_initialized() + return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration)) + + def _jduration(self, seconds): + """ + Create Duration object given number of seconds + """ + return self._jvm.Duration(int(seconds * 1000)) - def _start_callback_server(self): - gw = self._sc._gateway + @classmethod + def _ensure_initialized(cls): + SparkContext._ensure_initialized() + gw = SparkContext._gateway + # start callback server # getattr will fallback to JVM if "_callback_server" not in gw.__dict__: _daemonize_callback_server() gw._start_callback_server(gw._python_proxy_port) - gw._python_proxy_port = gw._callback_server.port # update port with real port - def _initialize_context(self, sc, duration): - java_import(self._jvm, "org.apache.spark.streaming.*") - java_import(self._jvm, "org.apache.spark.streaming.api.java.*") - java_import(self._jvm, "org.apache.spark.streaming.api.python.*") + java_import(gw.jvm, "org.apache.spark.streaming.*") + java_import(gw.jvm, "org.apache.spark.streaming.api.java.*") + java_import(gw.jvm, "org.apache.spark.streaming.api.python.*") # register serializer for RDDFunction - ser = RDDFunctionSerializer(self._sc, CloudPickleSerializer()) - self._jvm.PythonDStream.registerSerializer(ser) - return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration)) + # it happens before creating SparkContext when loading from checkpointing + cls._transformerSerializer = RDDFunctionSerializer(SparkContext._active_spark_context, + CloudPickleSerializer(), gw) + gw.jvm.PythonDStream.registerSerializer(cls._transformerSerializer) - def _jduration(self, seconds): + @classmethod + def getOrCreate(cls, path, setupFunc): """ - Create Duration object given number of seconds + Get the StreamingContext from checkpoint file at `path`, or setup + it by `setupFunc`. + + :param path: directory of checkpoint + :param setupFunc: a function used to create StreamingContext and + setup DStreams. 
+ :return: a StreamingContext """ - return self._jvm.Duration(int(seconds * 1000)) + if not os.path.exists(path) or not os.path.isdir(path) or not os.listdir(path): + ssc = setupFunc() + ssc.checkpoint(path) + return ssc + + cls._ensure_initialized() + gw = SparkContext._gateway + + try: + jssc = gw.jvm.JavaStreamingContext(path) + except Exception: + print >>sys.stderr, "failed to load StreamingContext from checkpoint" + raise + + jsc = jssc.sparkContext() + conf = SparkConf(_jconf=jsc.getConf()) + sc = SparkContext(conf=conf, gateway=gw, jsc=jsc) + # update ctx in serializer + SparkContext._active_spark_context = sc + cls._transformerSerializer.ctx = sc + return StreamingContext(sc, None, jssc) @property def sparkContext(self): diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index bd6d92255dbc6..00fea041d0be3 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -493,5 +493,38 @@ def func(rdds): self.assertEqual([2, 3, 1], self._take(dstream, 3)) +class TestCheckpoint(PySparkStreamingTestCase): + + def setUp(self): + pass + + def tearDown(self): + pass + + def test_get_or_create(self): + result = [0] + + def setup(): + conf = SparkConf().set("spark.default.parallelism", 1) + sc = SparkContext(conf=conf) + ssc = StreamingContext(sc, .2) + rdd = sc.parallelize(range(10), 1) + dstream = ssc.queueStream([rdd], default=rdd) + result[0] = self._collect(dstream.countByWindow(1, .2)) + return ssc + tmpd = tempfile.mkdtemp("test_streaming_cps") + ssc = StreamingContext.getOrCreate(tmpd, setup) + ssc.start() + ssc.awaitTermination(4) + ssc.stop() + expected = [[i * 10 + 10] for i in range(5)] + [[50]] * 5 + self.assertEqual(expected, result[0][:10]) + + ssc = StreamingContext.getOrCreate(tmpd, setup) + ssc.start() + ssc.awaitTermination(2) + ssc.stop() + + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index c15f9d98c1866..4cfaa3fc50e18 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -18,24 +18,31 @@ from datetime import datetime import traceback -from pyspark.rdd import RDD +from pyspark import SparkContext, RDD class RDDFunction(object): """ This class is for py4j callback. 
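
(The getOrCreate() added here is exercised by the checkpoint test later in this patch; a sketch of the intended driver-side usage, where the application name, input directory and checkpoint path are placeholders:)

    from pyspark import SparkContext
    from pyspark.streaming.context import StreamingContext

    def print_batch(batch_time, rdd):
        print batch_time, rdd.take(5)          # Python 2 print, matching the rest of the codebase

    def setup():
        sc = SparkContext(appName="MyStreamingApp")
        ssc = StreamingContext(sc, 1)
        lines = ssc.textFileStream("/data/incoming")     # placeholder input directory
        counts = (lines.flatMap(lambda x: x.split(" "))
                       .map(lambda x: (x, 1))
                       .reduceByKey(lambda a, b: a + b))
        counts.foreachRDD(print_batch)
        # no explicit ssc.checkpoint() needed here: getOrCreate() calls it on first creation
        return ssc

    ssc = StreamingContext.getOrCreate("/tmp/wordcount-checkpoint", setup)
    ssc.start()
    ssc.awaitTermination()
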
""" + _emptyRDD = None + def __init__(self, ctx, func, *deserializers): self.ctx = ctx self.func = func self.deserializers = deserializers - emptyRDD = getattr(self.ctx, "_emptyRDD", None) - if emptyRDD is None: - self.ctx._emptyRDD = emptyRDD = self.ctx.parallelize([]).cache() - self.emptyRDD = emptyRDD + + @property + def emptyRDD(self): + if self._emptyRDD is None and self.ctx: + self._emptyRDD = self.ctx.parallelize([]).cache() + return self._emptyRDD def call(self, milliseconds, jrdds): try: + if self.ctx is None: + self.ctx = SparkContext._active_spark_context + # extend deserializers with the first one sers = self.deserializers if len(sers) < len(jrdds): @@ -51,20 +58,21 @@ def call(self, milliseconds, jrdds): traceback.print_exc() def __repr__(self): - return "RDDFunction(%s)" % (str(self.func)) + return "RDDFunction(%s)" % self.func class Java: implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction'] class RDDFunctionSerializer(object): - def __init__(self, ctx, serializer): + def __init__(self, ctx, serializer, gateway=None): self.ctx = ctx self.serializer = serializer + self.gateway = gateway or self.ctx._gateway def dumps(self, id): try: - func = self.ctx._gateway.gateway_property.pool[id] + func = self.gateway.gateway_property.pool[id] return bytearray(self.serializer.dumps((func.func, func.deserializers))) except Exception: traceback.print_exc() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index f2ed0c507c2b7..48d1f2ae17e8c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -77,7 +77,7 @@ private[python] class RDDFunction(@transient var pfunc: PythonRDDFunction) } /** - * Inferface for Python Serializer to serialize PythonRDDFunction + * Interface for Python Serializer to serialize PythonRDDFunction */ private[python] trait PythonRDDFunctionSerializer { def dumps(id: String): Array[Byte] // @@ -91,9 +91,9 @@ private[python] class RDDFunctionSerializer(pser: PythonRDDFunctionSerializer) { def serialize(func: PythonRDDFunction): Array[Byte] = { // get the id of PythonRDDFunction in py4j val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy]) - val f = h.getClass().getDeclaredField("id"); - f.setAccessible(true); - val id = f.get(h).asInstanceOf[String]; + val f = h.getClass().getDeclaredField("id") + f.setAccessible(true) + val id = f.get(h).asInstanceOf[String] pser.dumps(id) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala index ed7da6dc1315e..0557ac87b5a1e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala @@ -17,6 +17,7 @@ package org.apache.spark.streaming.dstream +import org.apache.spark.SparkException import org.apache.spark.rdd.RDD import org.apache.spark.rdd.UnionRDD import scala.collection.mutable.Queue @@ -32,6 +33,12 @@ class QueueInputDStream[T: ClassTag]( defaultRDD: RDD[T] ) extends InputDStream[T](ssc) { + private[streaming] override def setContext(s: StreamingContext) { + super.setContext(s) + queue.map(_.setContext(s.sparkContext)) + defaultRDD.setContext(s.sparkContext) + } + override def start() { } 
override def stop() { } From d328aca2d5396ea75c7afffc4c45987c06fc43d9 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 1 Oct 2014 09:00:07 -0700 Subject: [PATCH 330/347] fix serializer in queueStream --- python/pyspark/streaming/context.py | 24 ++++++++++++++++-------- python/pyspark/streaming/dstream.py | 2 +- python/pyspark/streaming/tests.py | 6 +++--- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 759feda169cff..e3a34db566016 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -238,29 +238,37 @@ def textFileStream(self, directory): def _check_serialzers(self, rdds): # make sure they have same serializer - if len(set(rdd._jrdd_deserializer for rdd in rdds)): + if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1: for i in range(len(rdds)): # reset them to sc.serializer rdds[i] = rdds[i].map(lambda x: x, preservesPartitioning=True) - def queueStream(self, queue, oneAtATime=True, default=None): + def queueStream(self, rdds, oneAtATime=True, default=None): """ Create an input stream from an queue of RDDs or list. In each batch, it will process either one or all of the RDDs returned by the queue. NOTE: changes to the queue after the stream is created will not be recognized. - @param queue Queue of RDDs - @tparam T Type of objects in the RDD + + @param rdds Queue of RDDs + @param oneAtATime pick one rdd each time or pick all of them once. + @param default The default rdd if no more in rdds """ - if queue and not isinstance(queue[0], RDD): - rdds = [self._sc.parallelize(input) for input in queue] - else: - rdds = queue + if default and not isinstance(default, RDD): + default = self._sc.parallelize(default) + + if not rdds and default: + rdds = [rdds] + + if rdds and not isinstance(rdds[0], RDD): + rdds = [self._sc.parallelize(input) for input in rdds] self._check_serialzers(rdds) + jrdds = ListConverter().convert([r._jrdd for r in rdds], SparkContext._gateway._gateway_client) queue = self._jvm.PythonDStream.toRDDQueue(jrdds) if default: + default = default._reserialize(rdds[0]._jrdd_deserializer) jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd) else: jdstream = self._jssc.queueStream(queue, oneAtATime) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 87d5bb4906bd5..8fd6c68340381 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -292,7 +292,7 @@ def transformWith(self, func, other, keepSerializer=False): oldfunc = func func = lambda t, a, b: oldfunc(a, b) assert func.func_code.co_argcount == 3, "func should take two or three arguments" - jfunc = RDDFunction(self.ctx, func, self._jrdd_deserializer) + jfunc = RDDFunction(self.ctx, func, self._jrdd_deserializer, other._jrdd_deserializer) dstream = self.ctx._jvm.PythonTransformed2DStream(self._jdstream.dstream(), other._jdstream.dstream(), jfunc) jrdd_serializer = self._jrdd_deserializer if keepSerializer else self.ctx.serializer diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 00fea041d0be3..9e9a0847e7146 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -508,16 +508,16 @@ def setup(): conf = SparkConf().set("spark.default.parallelism", 1) sc = SparkContext(conf=conf) ssc = StreamingContext(sc, .2) - rdd = sc.parallelize(range(10), 1) + rdd = sc.parallelize(range(1), 1) dstream = ssc.queueStream([rdd], default=rdd) - 
result[0] = self._collect(dstream.countByWindow(1, .2)) + result[0] = self._collect(dstream.countByWindow(1, 0.2)) return ssc tmpd = tempfile.mkdtemp("test_streaming_cps") ssc = StreamingContext.getOrCreate(tmpd, setup) ssc.start() ssc.awaitTermination(4) ssc.stop() - expected = [[i * 10 + 10] for i in range(5)] + [[50]] * 5 + expected = [[i * 1 + 1] for i in range(5)] + [[5]] * 5 self.assertEqual(expected, result[0][:10]) ssc = StreamingContext.getOrCreate(tmpd, setup) From ff88bec11c497ab62225b945546949508a5b8347 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 1 Oct 2014 09:06:52 -0700 Subject: [PATCH 331/347] rename RDDFunction to TransformFunction --- python/pyspark/streaming/context.py | 16 +++--- python/pyspark/streaming/dstream.py | 16 +++--- python/pyspark/streaming/util.py | 14 ++--- .../streaming/api/python/PythonDStream.scala | 52 +++++++++---------- 4 files changed, 49 insertions(+), 49 deletions(-) diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index e3a34db566016..0f3662b9a54a6 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -25,7 +25,7 @@ from pyspark.context import SparkContext from pyspark.storagelevel import StorageLevel from pyspark.streaming.dstream import DStream -from pyspark.streaming.util import RDDFunction, RDDFunctionSerializer +from pyspark.streaming.util import TransformFunction, TransformFunctionSerializer __all__ = ["StreamingContext"] @@ -114,10 +114,10 @@ def _ensure_initialized(cls): java_import(gw.jvm, "org.apache.spark.streaming.*") java_import(gw.jvm, "org.apache.spark.streaming.api.java.*") java_import(gw.jvm, "org.apache.spark.streaming.api.python.*") - # register serializer for RDDFunction + # register serializer for TransformFunction # it happens before creating SparkContext when loading from checkpointing - cls._transformerSerializer = RDDFunctionSerializer(SparkContext._active_spark_context, - CloudPickleSerializer(), gw) + cls._transformerSerializer = TransformFunctionSerializer( + SparkContext._active_spark_context, CloudPickleSerializer(), gw) gw.jvm.PythonDStream.registerSerializer(cls._transformerSerializer) @classmethod @@ -284,10 +284,10 @@ def transform(self, dstreams, transformFunc): jdstreams = ListConverter().convert([d._jdstream for d in dstreams], SparkContext._gateway._gateway_client) # change the final serializer to sc.serializer - func = RDDFunction(self._sc, - lambda t, *rdds: transformFunc(rdds).map(lambda x: x), - *[d._jrdd_deserializer for d in dstreams]) - jfunc = self._jvm.RDDFunction(func) + func = TransformFunction(self._sc, + lambda t, *rdds: transformFunc(rdds).map(lambda x: x), + *[d._jrdd_deserializer for d in dstreams]) + jfunc = self._jvm.TransformFunction(func) jdstream = self._jssc.transform(jdstreams, jfunc) return DStream(jdstream, self, self._sc.serializer) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 8fd6c68340381..1b4a4421da0e0 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -22,7 +22,7 @@ from pyspark import RDD from pyspark.storagelevel import StorageLevel -from pyspark.streaming.util import rddToFileName, RDDFunction +from pyspark.streaming.util import rddToFileName, TransformFunction from pyspark.rdd import portable_hash from pyspark.resultiterable import ResultIterable @@ -154,7 +154,7 @@ def foreachRDD(self, func): """ Apply a function to each RDD in this DStream. 
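For example (a sketch only, assuming `ssc` is an already-created StreamingContext; the host and port are illustrative):

    lines = ssc.socketTextStream("localhost", 9999)

    def show_batch(time, rdd):
        # `time` is the batch time, `rdd` holds that batch's data
        print "batch at", time, ":", rdd.take(10)

    lines.foreachRDD(show_batch)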
""" - jfunc = RDDFunction(self.ctx, func, self._jrdd_deserializer) + jfunc = TransformFunction(self.ctx, func, self._jrdd_deserializer) api = self._ssc._jvm.PythonDStream api.callForeachRDD(self._jdstream, jfunc) @@ -292,7 +292,7 @@ def transformWith(self, func, other, keepSerializer=False): oldfunc = func func = lambda t, a, b: oldfunc(a, b) assert func.func_code.co_argcount == 3, "func should take two or three arguments" - jfunc = RDDFunction(self.ctx, func, self._jrdd_deserializer, other._jrdd_deserializer) + jfunc = TransformFunction(self.ctx, func, self._jrdd_deserializer, other._jrdd_deserializer) dstream = self.ctx._jvm.PythonTransformed2DStream(self._jdstream.dstream(), other._jdstream.dstream(), jfunc) jrdd_serializer = self._jrdd_deserializer if keepSerializer else self.ctx.serializer @@ -535,9 +535,9 @@ def invReduceFunc(t, a, b): joined = a.leftOuterJoin(b, numPartitions) return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1) - jreduceFunc = RDDFunction(self.ctx, reduceFunc, reduced._jrdd_deserializer) + jreduceFunc = TransformFunction(self.ctx, reduceFunc, reduced._jrdd_deserializer) if invReduceFunc: - jinvReduceFunc = RDDFunction(self.ctx, invReduceFunc, reduced._jrdd_deserializer) + jinvReduceFunc = TransformFunction(self.ctx, invReduceFunc, reduced._jrdd_deserializer) else: jinvReduceFunc = None if slideDuration is None: @@ -568,8 +568,8 @@ def reduceFunc(t, a, b): state = g.mapPartitions(lambda x: updateFunc(x)) return state.filter(lambda (k, v): v is not None) - jreduceFunc = RDDFunction(self.ctx, reduceFunc, - self.ctx.serializer, self._jrdd_deserializer) + jreduceFunc = TransformFunction(self.ctx, reduceFunc, + self.ctx.serializer, self._jrdd_deserializer) dstream = self.ctx._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc) return DStream(dstream.asJavaDStream(), self._ssc, self.ctx.serializer) @@ -609,7 +609,7 @@ def _jdstream(self): return self._jdstream_val func = self.func - jfunc = RDDFunction(self.ctx, func, self.prev._jrdd_deserializer) + jfunc = TransformFunction(self.ctx, func, self.prev._jrdd_deserializer) jdstream = self.ctx._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc, self.reuse).asJavaDStream() self._jdstream_val = jdstream diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index 4cfaa3fc50e18..4f07e44aa2d43 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -21,7 +21,7 @@ from pyspark import SparkContext, RDD -class RDDFunction(object): +class TransformFunction(object): """ This class is for py4j callback. 
""" @@ -58,13 +58,13 @@ def call(self, milliseconds, jrdds): traceback.print_exc() def __repr__(self): - return "RDDFunction(%s)" % self.func + return "TransformFunction(%s)" % self.func class Java: - implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction'] + implements = ['org.apache.spark.streaming.api.python.PythonTransformFunction'] -class RDDFunctionSerializer(object): +class TransformFunctionSerializer(object): def __init__(self, ctx, serializer, gateway=None): self.ctx = ctx self.serializer = serializer @@ -80,15 +80,15 @@ def dumps(self, id): def loads(self, bytes): try: f, deserializers = self.serializer.loads(str(bytes)) - return RDDFunction(self.ctx, f, *deserializers) + return TransformFunction(self.ctx, f, *deserializers) except Exception: traceback.print_exc() def __repr__(self): - return "RDDFunctionSerializer(%s)" % self.serializer + return "TransformFunctionSerializer(%s)" % self.serializer class Java: - implements = ['org.apache.spark.streaming.api.python.PythonRDDFunctionSerializer'] + implements = ['org.apache.spark.streaming.api.python.PythonTransformFunctionSerializer'] def rddToFileName(prefix, suffix, time): diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 48d1f2ae17e8c..59bb2ed5fa042 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -35,15 +35,15 @@ import org.apache.spark.streaming.api.java._ /** * Interface for Python callback function with three arguments */ -private[python] trait PythonRDDFunction { +private[python] trait PythonTransformFunction { def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] } /** - * Wrapper for PythonRDDFunction + * Wrapper for PythonTransformFunction * TODO: support checkpoint */ -private[python] class RDDFunction(@transient var pfunc: PythonRDDFunction) +private[python] class TransformFunction(@transient var pfunc: PythonTransformFunction) extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] with Serializable { def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { @@ -77,19 +77,19 @@ private[python] class RDDFunction(@transient var pfunc: PythonRDDFunction) } /** - * Interface for Python Serializer to serialize PythonRDDFunction + * Interface for Python Serializer to serialize PythonTransformFunction */ -private[python] trait PythonRDDFunctionSerializer { +private[python] trait PythonTransformFunctionSerializer { def dumps(id: String): Array[Byte] // - def loads(bytes: Array[Byte]): PythonRDDFunction + def loads(bytes: Array[Byte]): PythonTransformFunction } /** - * Wrapper for PythonRDDFunctionSerializer + * Wrapper for PythonTransformFunctionSerializer */ -private[python] class RDDFunctionSerializer(pser: PythonRDDFunctionSerializer) { - def serialize(func: PythonRDDFunction): Array[Byte] = { - // get the id of PythonRDDFunction in py4j +private[python] class TransformFunctionSerializer(pser: PythonTransformFunctionSerializer) { + def serialize(func: PythonTransformFunction): Array[Byte] = { + // get the id of PythonTransformFunction in py4j val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy]) val f = h.getClass().getDeclaredField("id") f.setAccessible(true) @@ -97,7 +97,7 @@ private[python] class RDDFunctionSerializer(pser: PythonRDDFunctionSerializer) { pser.dumps(id) } - def 
deserialize(bytes: Array[Byte]): PythonRDDFunction = { + def deserialize(bytes: Array[Byte]): PythonTransformFunction = { pser.loads(bytes) } } @@ -107,18 +107,18 @@ private[python] class RDDFunctionSerializer(pser: PythonRDDFunctionSerializer) { */ private[python] object PythonDStream { - // A serializer in Python, used to serialize PythonRDDFunction - var serializer: RDDFunctionSerializer = _ + // A serializer in Python, used to serialize PythonTransformFunction + var serializer: TransformFunctionSerializer = _ // Register a serializer from Python, should be called during initialization - def registerSerializer(ser: PythonRDDFunctionSerializer) = { - serializer = new RDDFunctionSerializer(ser) + def registerSerializer(ser: PythonTransformFunctionSerializer) = { + serializer = new TransformFunctionSerializer(ser) } // helper function for DStream.foreachRDD(), // cannot be `foreachRDD`, it will confusing py4j - def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonRDDFunction) { - val func = new RDDFunction((pfunc)) + def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonTransformFunction) { + val func = new TransformFunction((pfunc)) jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time)) } @@ -134,10 +134,10 @@ private[python] object PythonDStream { * Base class for PythonDStream with some common methods */ private[python] -abstract class PythonDStream(parent: DStream[_], @transient pfunc: PythonRDDFunction) +abstract class PythonDStream(parent: DStream[_], @transient pfunc: PythonTransformFunction) extends DStream[Array[Byte]] (parent.ssc) { - val func = new RDDFunction(pfunc) + val func = new TransformFunction(pfunc) override def dependencies = List(parent) @@ -153,7 +153,7 @@ abstract class PythonDStream(parent: DStream[_], @transient pfunc: PythonRDDFunc * as an template for future use, this can reduce the Python callbacks. 
*/ private[python] -class PythonTransformedDStream (parent: DStream[_], @transient pfunc: PythonRDDFunction, +class PythonTransformedDStream (parent: DStream[_], @transient pfunc: PythonTransformFunction, var reuse: Boolean = false) extends PythonDStream(parent, pfunc) { @@ -193,10 +193,10 @@ class PythonTransformedDStream (parent: DStream[_], @transient pfunc: PythonRDDF */ private[python] class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_], - @transient pfunc: PythonRDDFunction) + @transient pfunc: PythonTransformFunction) extends DStream[Array[Byte]] (parent.ssc) { - val func = new RDDFunction(pfunc) + val func = new TransformFunction(pfunc) override def slideDuration: Duration = parent.slideDuration @@ -213,7 +213,7 @@ class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_], * similar to StateDStream */ private[python] -class PythonStateDStream(parent: DStream[Array[Byte]], @transient reduceFunc: PythonRDDFunction) +class PythonStateDStream(parent: DStream[Array[Byte]], @transient reduceFunc: PythonTransformFunction) extends PythonDStream(parent, reduceFunc) { super.persist(StorageLevel.MEMORY_ONLY) @@ -235,8 +235,8 @@ class PythonStateDStream(parent: DStream[Array[Byte]], @transient reduceFunc: Py */ private[python] class PythonReducedWindowedDStream(parent: DStream[Array[Byte]], - @transient preduceFunc: PythonRDDFunction, - @transient pinvReduceFunc: PythonRDDFunction, + @transient preduceFunc: PythonTransformFunction, + @transient pinvReduceFunc: PythonTransformFunction, _windowDuration: Duration, _slideDuration: Duration ) extends PythonDStream(parent, preduceFunc) { @@ -244,7 +244,7 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]], super.persist(StorageLevel.MEMORY_ONLY) override val mustCheckpoint = true - val invReduceFunc = new RDDFunction(pinvReduceFunc) + val invReduceFunc = new TransformFunction(pinvReduceFunc) def windowDuration: Duration = _windowDuration override def slideDuration: Duration = _slideDuration From 7797c70f321b9ba5a66ad6a2744cf8e829dde011 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 1 Oct 2014 09:09:25 -0700 Subject: [PATCH 332/347] refactor --- .../org/apache/spark/streaming/api/python/PythonDStream.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 59bb2ed5fa042..5ab15f717903e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -174,7 +174,7 @@ class PythonTransformedDStream (parent: DStream[_], @transient pfunc: PythonTran // try to use the result as a template r.get match { case pyrdd: PythonRDD => - if (pyrdd.parent(0) == rdd) { + if (pyrdd.firstParent == rdd) { // only one PythonRDD lastResult = pyrdd } else { From bd8a4c2516147f1e99cf1f6e721346c18db23a20 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 1 Oct 2014 09:26:26 -0700 Subject: [PATCH 333/347] fix scala style --- .../streaming/api/python/PythonDStream.scala | 47 ++++++++++--------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 5ab15f717903e..5afcb84857350 100644 --- 
a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -47,7 +47,8 @@ private[python] class TransformFunction(@transient var pfunc: PythonTransformFun extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] with Serializable { def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { - Option(pfunc.call(time.milliseconds, List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava)).map(_.rdd) + Option(pfunc.call(time.milliseconds, List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava)) + .map(_.rdd) } def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { @@ -133,8 +134,9 @@ private[python] object PythonDStream { /** * Base class for PythonDStream with some common methods */ -private[python] -abstract class PythonDStream(parent: DStream[_], @transient pfunc: PythonTransformFunction) +private[python] abstract class PythonDStream( + parent: DStream[_], + @transient pfunc: PythonTransformFunction) extends DStream[Array[Byte]] (parent.ssc) { val func = new TransformFunction(pfunc) @@ -152,9 +154,10 @@ abstract class PythonDStream(parent: DStream[_], @transient pfunc: PythonTransfo * If `reuse` is true and the result of the `func` is an PythonRDD, then it will cache it * as an template for future use, this can reduce the Python callbacks. */ -private[python] -class PythonTransformedDStream (parent: DStream[_], @transient pfunc: PythonTransformFunction, - var reuse: Boolean = false) +private[python] class PythonTransformedDStream ( + parent: DStream[_], + @transient pfunc: PythonTransformFunction, + var reuse: Boolean = false) extends PythonDStream(parent, pfunc) { // rdd returned by func @@ -191,9 +194,10 @@ class PythonTransformedDStream (parent: DStream[_], @transient pfunc: PythonTran /** * Transformed from two DStreams in Python. 
*/ -private[python] -class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_], - @transient pfunc: PythonTransformFunction) +private[python] class PythonTransformed2DStream( + parent: DStream[_], + parent2: DStream[_], + @transient pfunc: PythonTransformFunction) extends DStream[Array[Byte]] (parent.ssc) { val func = new TransformFunction(pfunc) @@ -212,8 +216,9 @@ class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_], /** * similar to StateDStream */ -private[python] -class PythonStateDStream(parent: DStream[Array[Byte]], @transient reduceFunc: PythonTransformFunction) +private[python] class PythonStateDStream( + parent: DStream[Array[Byte]], + @transient reduceFunc: PythonTransformFunction) extends PythonDStream(parent, reduceFunc) { super.persist(StorageLevel.MEMORY_ONLY) @@ -233,13 +238,13 @@ class PythonStateDStream(parent: DStream[Array[Byte]], @transient reduceFunc: Py /** * similar to ReducedWindowedDStream */ -private[python] -class PythonReducedWindowedDStream(parent: DStream[Array[Byte]], - @transient preduceFunc: PythonTransformFunction, - @transient pinvReduceFunc: PythonTransformFunction, - _windowDuration: Duration, - _slideDuration: Duration - ) extends PythonDStream(parent, preduceFunc) { +private[python] class PythonReducedWindowedDStream( + parent: DStream[Array[Byte]], + @transient preduceFunc: PythonTransformFunction, + @transient pinvReduceFunc: PythonTransformFunction, + _windowDuration: Duration, + _slideDuration: Duration) + extends PythonDStream(parent, preduceFunc) { super.persist(StorageLevel.MEMORY_ONLY) override val mustCheckpoint = true @@ -252,8 +257,7 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]], override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { val currentTime = validTime - val current = new Interval(currentTime - windowDuration, - currentTime) + val current = new Interval(currentTime - windowDuration, currentTime) val previous = current - slideDuration // _____________________________ @@ -266,11 +270,10 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]], // V V // old RDDs new RDDs // - val previousRDD = getOrCompute(previous.endTime) + // for small window, reduce once will be better than twice if (pinvReduceFunc != null && previousRDD.isDefined - // for small window, reduce once will be better than twice && windowDuration >= slideDuration * 5) { // subtract the values from old RDDs From 7a88f9f1b054468b40e3134d7f4e0be8aacb03fa Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 1 Oct 2014 12:11:40 -0700 Subject: [PATCH 334/347] rollback RDD.setContext(), use textFileStream() to test checkpointing --- .../spark/rdd/ParallelCollectionRDD.scala | 2 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 8 --- python/pyspark/streaming/tests.py | 52 ++++++++++--------- .../streaming/dstream/QueueInputDStream.scala | 7 --- 4 files changed, 28 insertions(+), 41 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index 1069e23241302..66c71bf7e8bb5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -84,7 +84,7 @@ private[spark] class ParallelCollectionPartition[T: ClassTag]( private[spark] class ParallelCollectionRDD[T: ClassTag]( @transient sc: SparkContext, - data: Seq[T], + @transient data: Seq[T], numSlices: Int, locationPrefs: Map[Int, Seq[String]]) 
extends RDD[T](sc, Nil) { diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 352ce5e00d5ec..0e90caa5c9ca7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -82,14 +82,6 @@ abstract class RDD[T: ClassTag]( def this(@transient oneParent: RDD[_]) = this(oneParent.context , List(new OneToOneDependency(oneParent))) - // setContext after loading from checkpointing - private[spark] def setContext(s: SparkContext) = { - if (sc != null && sc != s) { - throw new SparkException("Context is already set in " + this + ", cannot set it again") - } - sc = s - } - private[spark] def conf = sc.conf // ======================================================================= // Methods that should be implemented by subclasses of RDD diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 9e9a0847e7146..b489c8b3f46f3 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -70,7 +70,8 @@ def _collect(self, dstream): def get_output(_, rdd): r = rdd.collect() - result.append(r) + if r: + result.append(r) dstream.foreachRDD(get_output) return result @@ -449,24 +450,18 @@ def test_queueStream(self): time.sleep(1) self.assertEqual(input, result[:3]) - # TODO: fix this test - # def test_textFileStream(self): - # input = [range(i) for i in range(3)] - # dstream = self.ssc.queueStream(input) - # d = os.path.join(tempfile.gettempdir(), str(id(self))) - # if not os.path.exists(d): - # os.makedirs(d) - # dstream.saveAsTextFiles(os.path.join(d, 'test')) - # self.ssc.start() - # time.sleep(1) - # self.ssc.stop(False, True) - # - # self.ssc = StreamingContext(self.sc, self.batachDuration) - # dstream2 = self.ssc.textFileStream(d) - # result = self._collect(dstream2) - # self.ssc.start() - # time.sleep(2) - # self.assertEqual(input, result[:3]) + def test_textFileStream(self): + d = tempfile.mkdtemp() + self.ssc = StreamingContext(self.sc, self.duration) + dstream2 = self.ssc.textFileStream(d).map(int) + result = self._collect(dstream2) + self.ssc.start() + time.sleep(1) + for name in ('a', 'b'): + with open(os.path.join(d, name), "w") as f: + f.writelines(["%d\n" % i for i in range(10)]) + time.sleep(2) + self.assertEqual([range(10) * 2], result[:3]) def test_union(self): input = [range(i) for i in range(3)] @@ -503,27 +498,34 @@ def tearDown(self): def test_get_or_create(self): result = [0] + inputd = tempfile.mkdtemp() def setup(): conf = SparkConf().set("spark.default.parallelism", 1) sc = SparkContext(conf=conf) ssc = StreamingContext(sc, .2) - rdd = sc.parallelize(range(1), 1) - dstream = ssc.queueStream([rdd], default=rdd) - result[0] = self._collect(dstream.countByWindow(1, 0.2)) + dstream = ssc.textFileStream(inputd) + result[0] = self._collect(dstream.count()) return ssc + tmpd = tempfile.mkdtemp("test_streaming_cps") ssc = StreamingContext.getOrCreate(tmpd, setup) ssc.start() + time.sleep(1) + with open(os.path.join(inputd, "1"), 'w') as f: + f.writelines(["%d\n" % i for i in range(10)]) ssc.awaitTermination(4) - ssc.stop() + ssc.stop(True, True) expected = [[i * 1 + 1] for i in range(5)] + [[5]] * 5 - self.assertEqual(expected, result[0][:10]) + self.assertEqual([[10]], result[0][:1]) ssc = StreamingContext.getOrCreate(tmpd, setup) ssc.start() + time.sleep(1) + with open(os.path.join(inputd, "1"), 'w') as f: + f.writelines(["%d\n" % i for i in range(10)]) ssc.awaitTermination(2) - ssc.stop() + 
ssc.stop(True, True) if __name__ == "__main__": diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala index 0557ac87b5a1e..ed7da6dc1315e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala @@ -17,7 +17,6 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.SparkException import org.apache.spark.rdd.RDD import org.apache.spark.rdd.UnionRDD import scala.collection.mutable.Queue @@ -33,12 +32,6 @@ class QueueInputDStream[T: ClassTag]( defaultRDD: RDD[T] ) extends InputDStream[T](ssc) { - private[streaming] override def setContext(s: StreamingContext) { - super.setContext(s) - queue.map(_.setContext(s.sparkContext)) - defaultRDD.setContext(s.sparkContext) - } - override def start() { } override def stop() { } From 54bd92b5800ce9165e53289c44603e6a89c5ed75 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 1 Oct 2014 23:34:41 -0700 Subject: [PATCH 335/347] improve tests --- python/pyspark/streaming/context.py | 1 - python/pyspark/streaming/dstream.py | 29 +++- python/pyspark/streaming/tests.py | 150 ++++++++++-------- python/pyspark/streaming/util.py | 23 +-- .../streaming/api/python/PythonDStream.scala | 56 ++++--- 5 files changed, 151 insertions(+), 108 deletions(-) diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 0f3662b9a54a6..b84e12ebac1dc 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -118,7 +118,6 @@ def _ensure_initialized(cls): # it happens before creating SparkContext when loading from checkpointing cls._transformerSerializer = TransformFunctionSerializer( SparkContext._active_spark_context, CloudPickleSerializer(), gw) - gw.jvm.PythonDStream.registerSerializer(cls._transformerSerializer) @classmethod def getOrCreate(cls, path, setupFunc): diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 1b4a4421da0e0..f8ebb7e68d8d7 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -20,6 +20,8 @@ import time from datetime import datetime +from py4j.protocol import Py4JJavaError + from pyspark import RDD from pyspark.storagelevel import StorageLevel from pyspark.streaming.util import rddToFileName, TransformFunction @@ -249,9 +251,15 @@ def saveAsTextFiles(self, prefix, suffix=None): Save each RDD in this DStream as at text file, using string representation of elements. """ - def saveAsTextFile(time, rdd): - path = rddToFileName(prefix, suffix, time) - rdd.saveAsTextFile(path) + def saveAsTextFile(t, rdd): + path = rddToFileName(prefix, suffix, t) + try: + rdd.saveAsTextFile(path) + except Py4JJavaError as e: + # after recovered from checkpointing, the foreachRDD may + # be called twice + if 'FileAlreadyExistsException' not in str(e): + raise return self.foreachRDD(saveAsTextFile) def _saveAsPickleFiles(self, prefix, suffix=None): @@ -259,9 +267,15 @@ def _saveAsPickleFiles(self, prefix, suffix=None): Save each RDD in this DStream as at binary file, the elements are serialized by pickle. 
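As a usage sketch (not part of the patch; paths are illustrative and `ssc` is assumed to exist as in the earlier sketches), each micro-batch is written to its own directory named prefix-<batch time>(.suffix), matching the rddToFileName helper shown later in this series:

    lines = ssc.textFileStream("/tmp/streaming-input")
    pairs = lines.map(lambda x: (x, 1))
    pairs.saveAsTextFiles("/tmp/streaming-output/wordcount", "txt")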
""" - def saveAsPickleFile(time, rdd): - path = rddToFileName(prefix, suffix, time) - rdd.saveAsPickleFile(path) + def saveAsPickleFile(t, rdd): + path = rddToFileName(prefix, suffix, t) + try: + rdd.saveAsPickleFile(path) + except Py4JJavaError as e: + # after recovered from checkpointing, the foreachRDD may + # be called twice + if 'FileAlreadyExistsException' not in str(e): + raise return self.foreachRDD(saveAsPickleFile) def transform(self, func): @@ -608,8 +622,7 @@ def _jdstream(self): if self._jdstream_val is not None: return self._jdstream_val - func = self.func - jfunc = TransformFunction(self.ctx, func, self.prev._jrdd_deserializer) + jfunc = TransformFunction(self.ctx, self.func, self.prev._jrdd_deserializer) jdstream = self.ctx._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc, self.reuse).asJavaDStream() self._jdstream_val = jdstream diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index b489c8b3f46f3..ff5986776a94e 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -42,6 +42,13 @@ def setUp(self): def tearDown(self): self.ssc.stop() + def wait_for(self, result, n): + start_time = time.time() + while len(result) < n and time.time() - start_time < self.timeout: + time.sleep(0.01) + if len(result) < n: + print "timeout after", self.timeout + def _take(self, dstream, n): """ Return the first `n` elements in the stream (will start and stop). @@ -55,12 +62,10 @@ def take(_, rdd): dstream.foreachRDD(take) self.ssc.start() - while len(results) < n: - time.sleep(0.01) - self.ssc.stop(False, True) + self.wait_for(results, n) return results - def _collect(self, dstream): + def _collect(self, dstream, n, block=True): """ Collect each RDDs into the returned list. @@ -69,10 +74,18 @@ def _collect(self, dstream): result = [] def get_output(_, rdd): - r = rdd.collect() - if r: - result.append(r) + if rdd and len(result) < n: + r = rdd.collect() + if r: + result.append(r) + dstream.foreachRDD(get_output) + + if not block: + return result + + self.ssc.start() + self.wait_for(result, n) return result def _test_func(self, input, func, expected, sort=False, input2=None): @@ -94,23 +107,7 @@ def _test_func(self, input, func, expected, sort=False, input2=None): else: stream = func(input_stream) - result = self._collect(stream) - self.ssc.start() - - start_time = time.time() - # Loop until get the expected the number of the result from the stream. - while True: - current_time = time.time() - # Check time out. - if (current_time - start_time) > self.timeout: - print "timeout after", self.timeout - break - # StreamingContext.awaitTermination is not used to wait because - # if py4j server is called every 50 milliseconds, it gets an error. - time.sleep(0.05) - # Check if the output is the same length of expected output. 
- if len(expected) == len(result): - break + result = self._collect(stream, len(expected)) if sort: self._sort_result_based_on_key(result) self._sort_result_based_on_key(expected) @@ -424,55 +421,50 @@ class TestStreamingContext(PySparkStreamingTestCase): duration = 0.1 + def _add_input_stream(self): + inputs = map(lambda x: range(1, x), range(101)) + stream = self.ssc.queueStream(inputs) + self._collect(stream, 1, block=False) + def test_stop_only_streaming_context(self): - self._addInputStream() + self._add_input_stream() self.ssc.start() self.ssc.stop(False) self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5) def test_stop_multiple_times(self): - self._addInputStream() + self._add_input_stream() self.ssc.start() self.ssc.stop() self.ssc.stop() - def _addInputStream(self): - # Make sure each length of input is over 3 - inputs = map(lambda x: range(1, x), range(5, 101)) - stream = self.ssc.queueStream(inputs) - self._collect(stream) - - def test_queueStream(self): - input = [range(i) for i in range(3)] + def test_queue_stream(self): + input = [range(i + 1) for i in range(3)] dstream = self.ssc.queueStream(input) - result = self._collect(dstream) - self.ssc.start() - time.sleep(1) - self.assertEqual(input, result[:3]) + result = self._collect(dstream, 3) + self.assertEqual(input, result) - def test_textFileStream(self): + def test_text_file_stream(self): d = tempfile.mkdtemp() self.ssc = StreamingContext(self.sc, self.duration) dstream2 = self.ssc.textFileStream(d).map(int) - result = self._collect(dstream2) + result = self._collect(dstream2, 2, block=False) self.ssc.start() - time.sleep(1) for name in ('a', 'b'): + time.sleep(1) with open(os.path.join(d, name), "w") as f: f.writelines(["%d\n" % i for i in range(10)]) - time.sleep(2) - self.assertEqual([range(10) * 2], result[:3]) + self.wait_for(result, 2) + self.assertEqual([range(10), range(10)], result) def test_union(self): - input = [range(i) for i in range(3)] + input = [range(i + 1) for i in range(3)] dstream = self.ssc.queueStream(input) dstream2 = self.ssc.queueStream(input) dstream3 = self.ssc.union(dstream, dstream2) - result = self._collect(dstream3) - self.ssc.start() - time.sleep(1) + result = self._collect(dstream3, 3) expected = [i * 2 for i in input] - self.assertEqual(expected, result[:3]) + self.assertEqual(expected, result) def test_transform(self): dstream1 = self.ssc.queueStream([[1]]) @@ -497,34 +489,62 @@ def tearDown(self): pass def test_get_or_create(self): - result = [0] inputd = tempfile.mkdtemp() + outputd = tempfile.mkdtemp() + "/" + + def updater(it): + for k, vs, s in it: + yield (k, sum(vs, s or 0)) def setup(): conf = SparkConf().set("spark.default.parallelism", 1) sc = SparkContext(conf=conf) - ssc = StreamingContext(sc, .2) - dstream = ssc.textFileStream(inputd) - result[0] = self._collect(dstream.count()) + ssc = StreamingContext(sc, 0.2) + dstream = ssc.textFileStream(inputd).map(lambda x: (x, 1)) + wc = dstream.updateStateByKey(updater) + wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test") + wc.checkpoint(.2) return ssc - tmpd = tempfile.mkdtemp("test_streaming_cps") - ssc = StreamingContext.getOrCreate(tmpd, setup) + cpd = tempfile.mkdtemp("test_streaming_cps") + ssc = StreamingContext.getOrCreate(cpd, setup) ssc.start() - time.sleep(1) - with open(os.path.join(inputd, "1"), 'w') as f: - f.writelines(["%d\n" % i for i in range(10)]) - ssc.awaitTermination(4) + + def check_output(n): + while not os.listdir(outputd): + time.sleep(0.1) + time.sleep(1) # make sure 
mtime is larger than the previous one + with open(os.path.join(inputd, str(n)), 'w') as f: + f.writelines(["%d\n" % i for i in range(10)]) + + while True: + p = os.path.join(outputd, max(os.listdir(outputd))) + if '_SUCCESS' not in os.listdir(p): + # not finished + time.sleep(0.01) + continue + ordd = ssc.sparkContext.textFile(p).map(lambda line: line.split(",")) + d = ordd.values().map(int).collect() + if not d: + time.sleep(0.01) + continue + self.assertEqual(10, len(d)) + s = set(d) + self.assertEqual(1, len(s)) + m = s.pop() + if n > m: + continue + self.assertEqual(n, m) + break + + check_output(1) + check_output(2) ssc.stop(True, True) - expected = [[i * 1 + 1] for i in range(5)] + [[5]] * 5 - self.assertEqual([[10]], result[0][:1]) - ssc = StreamingContext.getOrCreate(tmpd, setup) - ssc.start() time.sleep(1) - with open(os.path.join(inputd, "1"), 'w') as f: - f.writelines(["%d\n" % i for i in range(10)]) - ssc.awaitTermination(2) + ssc = StreamingContext.getOrCreate(cpd, setup) + ssc.start() + check_output(3) ssc.stop(True, True) diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index 4f07e44aa2d43..aecf7f71fdbc7 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -15,6 +15,7 @@ # limitations under the License. # +import time from datetime import datetime import traceback @@ -32,23 +33,20 @@ def __init__(self, ctx, func, *deserializers): self.func = func self.deserializers = deserializers - @property - def emptyRDD(self): - if self._emptyRDD is None and self.ctx: - self._emptyRDD = self.ctx.parallelize([]).cache() - return self._emptyRDD - def call(self, milliseconds, jrdds): try: if self.ctx is None: self.ctx = SparkContext._active_spark_context + if not self.ctx or not self.ctx._jsc: + # stopped + return # extend deserializers with the first one sers = self.deserializers if len(sers) < len(jrdds): sers += (sers[0],) * (len(jrdds) - len(sers)) - rdds = [RDD(jrdd, self.ctx, ser) if jrdd else self.emptyRDD + rdds = [RDD(jrdd, self.ctx, ser) if jrdd else None for jrdd, ser in zip(jrdds, sers)] t = datetime.fromtimestamp(milliseconds / 1000.0) r = self.func(t, *rdds) @@ -69,6 +67,7 @@ def __init__(self, ctx, serializer, gateway=None): self.ctx = ctx self.serializer = serializer self.gateway = gateway or self.ctx._gateway + self.gateway.jvm.PythonDStream.registerSerializer(self) def dumps(self, id): try: @@ -91,7 +90,7 @@ class Java: implements = ['org.apache.spark.streaming.api.python.PythonTransformFunctionSerializer'] -def rddToFileName(prefix, suffix, time): +def rddToFileName(prefix, suffix, timestamp): """ Return string prefix-time(.suffix) @@ -99,12 +98,14 @@ def rddToFileName(prefix, suffix, time): 'spark-12345678910' >>> rddToFileName("spark", "tmp", 12345678910) 'spark-12345678910.tmp' - """ + if isinstance(timestamp, datetime): + seconds = time.mktime(timestamp.timetuple()) + timestamp = long(seconds * 1000) + timestamp.microsecond / 1000 if suffix is None: - return prefix + "-" + str(time) + return prefix + "-" + str(timestamp) else: - return prefix + "-" + str(time) + "." + suffix + return prefix + "-" + str(timestamp) + "." 
+ suffix if __name__ == "__main__": diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 5afcb84857350..59552bb0a2205 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -22,6 +22,7 @@ import java.lang.reflect.Proxy import java.util.{ArrayList => JArrayList, List => JList} import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ +import scala.language.existentials import org.apache.spark.api.java._ import org.apache.spark.api.python._ @@ -39,9 +40,16 @@ private[python] trait PythonTransformFunction { def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] } +/** + * Interface for Python Serializer to serialize PythonTransformFunction + */ +private[python] trait PythonTransformFunctionSerializer { + def dumps(id: String): Array[Byte] + def loads(bytes: Array[Byte]): PythonTransformFunction +} + /** * Wrapper for PythonTransformFunction - * TODO: support checkpoint */ private[python] class TransformFunction(@transient var pfunc: PythonTransformFunction) extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] with Serializable { @@ -62,44 +70,45 @@ private[python] class TransformFunction(@transient var pfunc: PythonTransformFun } private def writeObject(out: ObjectOutputStream): Unit = { - assert(PythonDStream.serializer != null, "Serializer has not been registered!") - val bytes = PythonDStream.serializer.serialize(pfunc) + val bytes = PythonTransformFunctionSerializer.serialize(pfunc) out.writeInt(bytes.length) out.write(bytes) } private def readObject(in: ObjectInputStream): Unit = { - assert(PythonDStream.serializer != null, "Serializer has not been registered!") val length = in.readInt() val bytes = new Array[Byte](length) in.readFully(bytes) - pfunc = PythonDStream.serializer.deserialize(bytes) + pfunc = PythonTransformFunctionSerializer.deserialize(bytes) } } /** - * Interface for Python Serializer to serialize PythonTransformFunction + * Helpers for PythonTransformFunctionSerializer */ -private[python] trait PythonTransformFunctionSerializer { - def dumps(id: String): Array[Byte] // - def loads(bytes: Array[Byte]): PythonTransformFunction -} +private[python] object PythonTransformFunctionSerializer { + + // A serializer in Python, used to serialize PythonTransformFunction + private var serializer: PythonTransformFunctionSerializer = _ + + // Register a serializer from Python, should be called during initialization + def register(ser: PythonTransformFunctionSerializer): Unit = { + serializer = ser + } -/** - * Wrapper for PythonTransformFunctionSerializer - */ -private[python] class TransformFunctionSerializer(pser: PythonTransformFunctionSerializer) { def serialize(func: PythonTransformFunction): Array[Byte] = { + assert(serializer != null, "Serializer has not been registered!") // get the id of PythonTransformFunction in py4j val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy]) val f = h.getClass().getDeclaredField("id") f.setAccessible(true) val id = f.get(h).asInstanceOf[String] - pser.dumps(id) + serializer.dumps(id) } def deserialize(bytes: Array[Byte]): PythonTransformFunction = { - pser.loads(bytes) + assert(serializer != null, "Serializer has not been registered!") + serializer.loads(bytes) } } @@ -108,12 +117,10 @@ private[python] class 
TransformFunctionSerializer(pser: PythonTransformFunctionS */ private[python] object PythonDStream { - // A serializer in Python, used to serialize PythonTransformFunction - var serializer: TransformFunctionSerializer = _ - - // Register a serializer from Python, should be called during initialization - def registerSerializer(ser: PythonTransformFunctionSerializer) = { - serializer = new TransformFunctionSerializer(ser) + // can not access PythonTransformFunctionSerializer.register() via Py4j + // Py4JError: PythonTransformFunctionSerializerregister does not exist in the JVM + def registerSerializer(ser: PythonTransformFunctionSerializer): Unit = { + PythonTransformFunctionSerializer.register(ser) } // helper function for DStream.foreachRDD(), @@ -207,7 +214,10 @@ private[python] class PythonTransformed2DStream( override def dependencies = List(parent, parent2) override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { - func(parent.getOrCompute(validTime), parent2.getOrCompute(validTime), validTime) + val empty: RDD[_] = ssc.sparkContext.emptyRDD + val rdd1 = parent.getOrCompute(validTime).getOrElse(empty) + val rdd2 = parent2.getOrCompute(validTime).getOrElse(empty) + func(Some(rdd1), Some(rdd2), validTime) } val asJavaDStream = JavaDStream.fromDStream(this) From 4d0ea8bf5df513d5d1f4250286ca328192018f08 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 2 Oct 2014 00:10:38 -0700 Subject: [PATCH 336/347] clear reference of SparkEnv after stop --- core/src/main/scala/org/apache/spark/SparkEnv.scala | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 009ed64775844..57874df3819b2 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -91,6 +91,9 @@ class SparkEnv ( // actorSystem.awaitTermination() // Note that blockTransferService is stopped by BlockManager since it is started by it. + + // clear all the references in ThreadLocal object + SparkEnv.reset() } private[spark] @@ -119,7 +122,7 @@ class SparkEnv ( } object SparkEnv extends Logging { - private val env = new ThreadLocal[SparkEnv] + @volatile private var env = new ThreadLocal[SparkEnv] @volatile private var lastSetSparkEnv : SparkEnv = _ private[spark] val driverActorSystemName = "sparkDriver" @@ -130,6 +133,12 @@ object SparkEnv extends Logging { env.set(e) } + // clear all the threadlocal references + private[spark] def reset(): Unit = { + env = new ThreadLocal[SparkEnv] + lastSetSparkEnv = null + } + /** * Returns the ThreadLocal SparkEnv, if non-null. Else returns the SparkEnv * previously set in any thread. 
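Outside the test suite, the checkpoint-recovery flow exercised above looks roughly like the sketch below. The paths, batch interval and checkpoint interval are illustrative assumptions; the updater mirrors the iterator-style updateStateByKey signature used in the test.

    from pyspark import SparkConf, SparkContext
    from pyspark.streaming import StreamingContext

    checkpoint_dir = "/tmp/streaming-checkpoint"   # illustrative path
    input_dir = "/tmp/streaming-input"             # illustrative path

    def create_context():
        conf = SparkConf().set("spark.default.parallelism", 1)
        sc = SparkContext(conf=conf)
        ssc = StreamingContext(sc, 1)
        ssc.checkpoint(checkpoint_dir)

        def updater(it):
            # running count per key; `s` is None for keys seen the first time
            for k, vs, s in it:
                yield (k, sum(vs, s or 0))

        counts = ssc.textFileStream(input_dir).map(lambda x: (x, 1))
        state = counts.updateStateByKey(updater)
        state.checkpoint(10)
        state.saveAsTextFiles("/tmp/streaming-output/state")
        return ssc

    # On a first run this calls create_context(); after a restart it rebuilds
    # the context (and its Python transform functions) from the checkpoint.
    ssc = StreamingContext.getOrCreate(checkpoint_dir, create_context)
    ssc.start()
    ssc.awaitTermination()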
From c7bbbced7ba2d45e5fb2c1452920de11bd5138a8 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 2 Oct 2014 08:01:53 -0700 Subject: [PATCH 337/347] fix sphinx docs --- python/docs/epytext.py | 2 +- python/docs/index.rst | 1 + python/docs/modules.rst | 3 ++ python/docs/pyspark.rst | 3 +- python/pyspark/streaming/__init__.py | 2 + python/pyspark/streaming/context.py | 24 +++++----- python/pyspark/streaming/dstream.py | 65 ++++++++++++++-------------- 7 files changed, 54 insertions(+), 46 deletions(-) diff --git a/python/docs/epytext.py b/python/docs/epytext.py index 61d731bff570d..19fefbfc057a4 100644 --- a/python/docs/epytext.py +++ b/python/docs/epytext.py @@ -5,7 +5,7 @@ (r"L{([\w.()]+)}", r":class:`\1`"), (r"[LC]{(\w+\.\w+)\(\)}", r":func:`\1`"), (r"C{([\w.()]+)}", r":class:`\1`"), - (r"[IBCM]{(.+)}", r"`\1`"), + (r"[IBCM]{([^}]+)}", r"`\1`"), ('pyspark.rdd.RDD', 'RDD'), ) diff --git a/python/docs/index.rst b/python/docs/index.rst index 25b3f9bd93e63..e0f4e5c192acf 100644 --- a/python/docs/index.rst +++ b/python/docs/index.rst @@ -13,6 +13,7 @@ Contents: pyspark pyspark.sql + pyspark.streaming pyspark.mllib diff --git a/python/docs/modules.rst b/python/docs/modules.rst index 183564659fbcf..04dce62be5f49 100644 --- a/python/docs/modules.rst +++ b/python/docs/modules.rst @@ -5,3 +5,6 @@ :maxdepth: 4 pyspark + pyspark.sql + pyspark.streaming + pyspark.mllib diff --git a/python/docs/pyspark.rst b/python/docs/pyspark.rst index a68bd62433085..e81be3b6cb796 100644 --- a/python/docs/pyspark.rst +++ b/python/docs/pyspark.rst @@ -7,8 +7,9 @@ Subpackages .. toctree:: :maxdepth: 1 - pyspark.mllib pyspark.sql + pyspark.streaming + pyspark.mllib Contents -------- diff --git a/python/pyspark/streaming/__init__.py b/python/pyspark/streaming/__init__.py index 00d2823525992..d2644a1d4ffab 100644 --- a/python/pyspark/streaming/__init__.py +++ b/python/pyspark/streaming/__init__.py @@ -17,3 +17,5 @@ from pyspark.streaming.context import StreamingContext from pyspark.streaming.dstream import DStream + +__all__ = ['StreamingContext', 'DStream'] diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index b84e12ebac1dc..aabbbd958080a 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -71,7 +71,7 @@ class StreamingContext(object): """ Main entry point for Spark Streaming functionality. A StreamingContext represents the connection to a Spark cluster, and can be used to create - L{DStream}s various input sources. It can be from an existing L{SparkContext}. + L{DStream} various input sources. It can be from an existing L{SparkContext}. After creating and transforming DStreams, the streaming computation can be started and stopped using `context.start()` and `context.stop()`, respectively. `context.awaitTransformation()` allows the current thread @@ -180,8 +180,8 @@ def stop(self, stopSparkContext=True, stopGraceFully=False): Stop the execution of the streams, with option of ensuring all received data has been processed. - @param stopSparkContext Stop the associated SparkContext or not - @param stopGracefully Stop gracefully by waiting for the processing + @param stopSparkContext: Stop the associated SparkContext or not + @param stopGracefully: Stop gracefully by waiting for the processing of all received data to be completed """ self._jssc.stop(stopSparkContext, stopGraceFully) @@ -197,7 +197,7 @@ def remember(self, duration): the RDDs (if the developer wishes to query old data outside the DStream computation). 
- @param duration Minimum duration (in seconds) that each DStream + @param duration: Minimum duration (in seconds) that each DStream should remember its RDDs """ self._jssc.remember(self._jduration(duration)) @@ -207,7 +207,7 @@ def checkpoint(self, directory): Sets the context to periodically checkpoint the DStream operations for master fault-tolerance. The graph will be checkpointed every batch interval. - @param directory HDFS-compatible directory where the checkpoint data + @param directory: HDFS-compatible directory where the checkpoint data will be reliably stored """ self._jssc.checkpoint(directory) @@ -215,12 +215,12 @@ def checkpoint(self, directory): def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): """ Create an input from TCP source hostname:port. Data is received using - a TCP socket and receive byte is interpreted as UTF8 encoded '\n' delimited + a TCP socket and receive byte is interpreted as UTF8 encoded ``\\n`` delimited lines. - @param hostname Hostname to connect to for receiving data - @param port Port to connect to for receiving data - @param storageLevel Storage level to use for storing the received objects + @param hostname: Hostname to connect to for receiving data + @param port: Port to connect to for receiving data + @param storageLevel: Storage level to use for storing the received objects """ jlevel = self._sc._getJavaStorageLevel(storageLevel) return DStream(self._jssc.socketTextStream(hostname, port, jlevel), self, @@ -249,9 +249,9 @@ def queueStream(self, rdds, oneAtATime=True, default=None): NOTE: changes to the queue after the stream is created will not be recognized. - @param rdds Queue of RDDs - @param oneAtATime pick one rdd each time or pick all of them once. - @param default The default rdd if no more in rdds + @param rdds: Queue of RDDs + @param oneAtATime: pick one rdd each time or pick all of them once. + @param default: The default rdd if no more in rdds """ if default and not isinstance(default, RDD): default = self._sc.parallelize(default) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index f8ebb7e68d8d7..a77e8f505e147 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -284,7 +284,7 @@ def transform(self, func): on each RDD of 'this' DStream. `func` can have one argument of `rdd`, or have two arguments of - (`time`, `rdd`) + (`time`, `rdd`) """ resue = False if func.func_code.co_argcount == 1: @@ -328,7 +328,8 @@ def _slideDuration(self): def union(self, other): """ Return a new DStream by unifying data of another DStream with this DStream. - @param other Another DStream having the same interval (i.e., slideDuration) + + @param other: Another DStream having the same interval (i.e., slideDuration) as this DStream. """ if self._slideDuration != other._slideDuration: @@ -348,11 +349,11 @@ def cogroup(self, other, numPartitions=None): def join(self, other, numPartitions=None): """ - Return a new DStream by applying 'join' between RDDs of `this` DStream and + Return a new DStream by applying 'join' between RDDs of `this` DStream and `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions` - partitions. + partitions. 
""" if numPartitions is None: numPartitions = self.ctx.defaultParallelism @@ -360,11 +361,11 @@ def join(self, other, numPartitions=None): def leftOuterJoin(self, other, numPartitions=None): """ - Return a new DStream by applying 'left outer join' between RDDs of `this` DStream and + Return a new DStream by applying 'left outer join' between RDDs of `this` DStream and `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions` - partitions. + partitions. """ if numPartitions is None: numPartitions = self.ctx.defaultParallelism @@ -372,11 +373,11 @@ def leftOuterJoin(self, other, numPartitions=None): def rightOuterJoin(self, other, numPartitions=None): """ - Return a new DStream by applying 'right outer join' between RDDs of `this` DStream and + Return a new DStream by applying 'right outer join' between RDDs of `this` DStream and `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions` - partitions. + partitions. """ if numPartitions is None: numPartitions = self.ctx.defaultParallelism @@ -384,11 +385,11 @@ def rightOuterJoin(self, other, numPartitions=None): def fullOuterJoin(self, other, numPartitions=None): """ - Return a new DStream by applying 'full outer join' between RDDs of `this` DStream and + Return a new DStream by applying 'full outer join' between RDDs of `this` DStream and `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions` - partitions. + partitions. """ if numPartitions is None: numPartitions = self.ctx.defaultParallelism @@ -424,9 +425,9 @@ def window(self, windowDuration, slideDuration=None): Return a new DStream in which each RDD contains all the elements in seen in a sliding window of time over this DStream. - @param windowDuration width of the window; must be a multiple of this DStream's + @param windowDuration: width of the window; must be a multiple of this DStream's batching interval - @param slideDuration sliding interval of the window (i.e., the interval after which + @param slideDuration: sliding interval of the window (i.e., the interval after which the new DStream will generate RDDs); must be a multiple of this DStream's batching interval """ @@ -448,13 +449,13 @@ def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuratio 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) This is more efficient than `invReduceFunc` is None. 
- @param reduceFunc associative reduce function - @param invReduceFunc inverse reduce function of `reduceFunc` - @param windowDuration width of the window; must be a multiple of this DStream's - batching interval - @param slideDuration sliding interval of the window (i.e., the interval after which - the new DStream will generate RDDs); must be a multiple of this - DStream's batching interval + @param reduceFunc: associative reduce function + @param invReduceFunc: inverse reduce function of `reduceFunc` + @param windowDuration: width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration: sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval """ keyed = self.map(lambda x: (1, x)) reduced = keyed.reduceByKeyAndWindow(reduceFunc, invReduceFunc, @@ -478,12 +479,12 @@ def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=Non Return a new DStream in which each RDD contains the count of distinct elements in RDDs in a sliding window over this DStream. - @param windowDuration width of the window; must be a multiple of this DStream's + @param windowDuration: width of the window; must be a multiple of this DStream's batching interval - @param slideDuration sliding interval of the window (i.e., the interval after which + @param slideDuration: sliding interval of the window (i.e., the interval after which the new DStream will generate RDDs); must be a multiple of this DStream's batching interval - @param numPartitions number of partitions of each RDD in the new DStream. + @param numPartitions: number of partitions of each RDD in the new DStream. """ keyed = self.map(lambda x: (x, 1)) counted = keyed.reduceByKeyAndWindow(operator.add, operator.sub, @@ -495,12 +496,12 @@ def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None) Return a new DStream by applying `groupByKey` over a sliding window. Similar to `DStream.groupByKey()`, but applies it over a sliding window. - @param windowDuration width of the window; must be a multiple of this DStream's + @param windowDuration: width of the window; must be a multiple of this DStream's batching interval - @param slideDuration sliding interval of the window (i.e., the interval after which + @param slideDuration: sliding interval of the window (i.e., the interval after which the new DStream will generate RDDs); must be a multiple of this DStream's batching interval - @param numPartitions Number of partitions of each RDD in the new DStream. + @param numPartitions: Number of partitions of each RDD in the new DStream. """ ls = self.mapValues(lambda x: [x]) grouped = ls.reduceByKeyAndWindow(lambda a, b: a.extend(b) or a, lambda a, b: a[len(b):], @@ -519,15 +520,15 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None `invFunc` can be None, then it will reduce all the RDDs in window, could be slower than having `invFunc`. 
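For example (a sketch, assuming `pairs` is a DStream of (key, 1) pairs as in the earlier sketches; durations are in seconds):

    windowed = pairs.reduceByKeyAndWindow(
        lambda a, b: a + b,   # fold values entering the window
        lambda a, b: a - b,   # "inverse" function for values leaving the window
        windowDuration=30,
        slideDuration=10)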
- @param reduceFunc associative reduce function - @param invReduceFunc inverse function of `reduceFunc` - @param windowDuration width of the window; must be a multiple of this DStream's + @param reduceFunc: associative reduce function + @param invReduceFunc: inverse function of `reduceFunc` + @param windowDuration: width of the window; must be a multiple of this DStream's batching interval - @param slideDuration sliding interval of the window (i.e., the interval after which + @param slideDuration: sliding interval of the window (i.e., the interval after which the new DStream will generate RDDs); must be a multiple of this DStream's batching interval - @param numPartitions number of partitions of each RDD in the new DStream. - @param filterFunc function to filter expired key-value pairs; + @param numPartitions: number of partitions of each RDD in the new DStream. + @param filterFunc: function to filter expired key-value pairs; only pairs that satisfy the function are retained set this to null if you do not want to filter """ @@ -567,7 +568,7 @@ def updateStateByKey(self, updateFunc, numPartitions=None): Return a new "state" DStream where the state for each key is updated by applying the given function on the previous state of the key and the new values of the key. - @param updateFunc State update function ([(k, vs, s)] -> [(k, s)]). + @param updateFunc: State update function ([(k, vs, s)] -> [(k, s)]). If `s` is None, then `k` will be eliminated. """ if numPartitions is None: From be5e5ffdc5d2606042f09adb8d0fff08ddc4b85d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 2 Oct 2014 08:27:52 -0700 Subject: [PATCH 338/347] merge branch of env, make tests stable. --- python/pyspark/streaming/tests.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index ff5986776a94e..6a7dfd574701d 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -485,9 +485,6 @@ class TestCheckpoint(PySparkStreamingTestCase): def setUp(self): pass - def tearDown(self): - pass - def test_get_or_create(self): inputd = tempfile.mkdtemp() outputd = tempfile.mkdtemp() + "/" @@ -545,7 +542,6 @@ def check_output(n): ssc = StreamingContext.getOrCreate(cpd, setup) ssc.start() check_output(3) - ssc.stop(True, True) if __name__ == "__main__": From d05871e912ee4828a4ac68a6a0ceed0454e44722 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 2 Oct 2014 12:46:50 -0700 Subject: [PATCH 339/347] remove reuse of PythonRDD --- python/pyspark/streaming/dstream.py | 28 ++++++------- python/pyspark/streaming/tests.py | 4 +- .../streaming/api/python/PythonDStream.scala | 39 ++++--------------- 3 files changed, 20 insertions(+), 51 deletions(-) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index a77e8f505e147..fddfd757b8674 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -286,13 +286,11 @@ def transform(self, func): `func` can have one argument of `rdd`, or have two arguments of (`time`, `rdd`) """ - resue = False if func.func_code.co_argcount == 1: - reuse = True oldfunc = func func = lambda t, rdd: oldfunc(rdd) assert func.func_code.co_argcount == 2, "func should take one or two arguments" - return TransformedDStream(self, func, reuse) + return TransformedDStream(self, func) def transformWith(self, func, other, keepSerializer=False): """ @@ -597,26 +595,23 @@ class TransformedDStream(DStream): Multiple continuous transformations of DStream can be 
combined into one transformation. """ - def __init__(self, prev, func, reuse=False): + def __init__(self, prev, func): ssc = prev._ssc self._ssc = ssc self.ctx = ssc._sc self._jrdd_deserializer = self.ctx.serializer self.is_cached = False self.is_checkpointed = False + self._jdstream_val = None if (isinstance(prev, TransformedDStream) and not prev.is_cached and not prev.is_checkpointed): prev_func = prev.func - old_func = func - func = lambda t, rdd: old_func(t, prev_func(t, rdd)) - reuse = reuse and prev.reuse - prev = prev.prev - - self.prev = prev - self.func = func - self.reuse = reuse - self._jdstream_val = None + self.func = lambda t, rdd: func(t, prev_func(t, rdd)) + self.prev = prev.prev + else: + self.prev = prev + self.func = func @property def _jdstream(self): @@ -624,7 +619,6 @@ def _jdstream(self): return self._jdstream_val jfunc = TransformFunction(self.ctx, self.func, self.prev._jrdd_deserializer) - jdstream = self.ctx._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), - jfunc, self.reuse).asJavaDStream() - self._jdstream_val = jdstream - return jdstream + dstream = self.ctx._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc) + self._jdstream_val = dstream.asJavaDStream() + return self._jdstream_val diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 6a7dfd574701d..a839faecf9a16 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -504,7 +504,7 @@ def setup(): return ssc cpd = tempfile.mkdtemp("test_streaming_cps") - ssc = StreamingContext.getOrCreate(cpd, setup) + self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup) ssc.start() def check_output(n): @@ -539,7 +539,7 @@ def check_output(n): ssc.stop(True, True) time.sleep(1) - ssc = StreamingContext.getOrCreate(cpd, setup) + self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup) ssc.start() check_output(3) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 59552bb0a2205..96b84b45b2ebf 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -157,43 +157,18 @@ private[python] abstract class PythonDStream( /** * Transformed DStream in Python. - * - * If `reuse` is true and the result of the `func` is an PythonRDD, then it will cache it - * as an template for future use, this can reduce the Python callbacks. 
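The fusion performed in TransformedDStream.__init__ above is plain function composition in the (time, rdd) -> rdd shape. A minimal, framework-free sketch of the idea, with plain lists standing in for RDDs (names and data are invented for illustration):

    # Two transform steps written in the (time, rdd) -> rdd shape used above;
    # plain lists stand in for RDDs so the sketch runs without Spark.
    step1 = lambda t, rdd: [x for x in rdd if x % 2 == 0]
    step2 = lambda t, rdd: [(t, x) for x in rdd]

    def fuse(prev_func, func):
        # Same trick as TransformedDStream.__init__: fold the upstream function
        # into the new one so the JVM only ever sees a single Python callback.
        return lambda t, rdd: func(t, prev_func(t, rdd))

    fused = fuse(step1, step2)
    print fused(0, range(6))   # [(0, 0), (0, 2), (0, 4)]

Note that the real code only fuses when the upstream stream is neither cached nor checkpointed, presumably because fusing past a cache or checkpoint point would bypass it.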
*/ private[python] class PythonTransformedDStream ( parent: DStream[_], - @transient pfunc: PythonTransformFunction, - var reuse: Boolean = false) + @transient pfunc: PythonTransformFunction) extends PythonDStream(parent, pfunc) { - // rdd returned by func - var lastResult: PythonRDD = _ - override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { val rdd = parent.getOrCompute(validTime) - if (rdd.isEmpty) { - return None - } - if (reuse && lastResult != null) { - // use the previous result as the template to generate new RDD - Some(lastResult.copyTo(rdd.get)) + if (rdd.isDefined) { + func(rdd, validTime) } else { - val r = func(rdd, validTime) - if (reuse && r.isDefined && lastResult == null) { - // try to use the result as a template - r.get match { - case pyrdd: PythonRDD => - if (pyrdd.firstParent == rdd) { - // only one PythonRDD - lastResult = pyrdd - } else { - // maybe have multiple stages, don't check it anymore - reuse = false - } - } - } - r + None } } } @@ -209,10 +184,10 @@ private[python] class PythonTransformed2DStream( val func = new TransformFunction(pfunc) - override def slideDuration: Duration = parent.slideDuration - override def dependencies = List(parent, parent2) + override def slideDuration: Duration = parent.slideDuration + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { val empty: RDD[_] = ssc.sparkContext.emptyRDD val rdd1 = parent.getOrCompute(validTime).getOrElse(empty) @@ -220,7 +195,7 @@ private[python] class PythonTransformed2DStream( func(Some(rdd1), Some(rdd2), validTime) } - val asJavaDStream = JavaDStream.fromDStream(this) + val asJavaDStream = JavaDStream.fromDStream(this) } /** From 37fe06fb743a1934d834d603e04b678110bc0fd5 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 2 Oct 2014 15:43:46 -0700 Subject: [PATCH 340/347] use random port for callback server --- python/pyspark/streaming/context.py | 30 +++++++++++----- .../streaming/api/python/PythonDStream.scala | 36 +++++++++++++++---- 2 files changed, 51 insertions(+), 15 deletions(-) diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index aabbbd958080a..7f99d38771ce8 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -18,7 +18,7 @@ import sys from py4j.java_collections import ListConverter -from py4j.java_gateway import java_import +from py4j.java_gateway import java_import, JavaObject from pyspark import RDD, SparkConf from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer @@ -38,6 +38,8 @@ def _daemonize_callback_server(): from exiting if it's not shutdown. The following code replace `start()` of CallbackServer with a new version, which set daemon=True for this thread. 
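The daemon=True detail just described is easy to demonstrate in isolation: a non-daemon thread running an accept-style loop would keep the interpreter alive forever, while a daemon thread is simply discarded at exit. A small sketch with no Py4J involved:

    import threading
    import time

    def serve_forever():
        # Stands in for the callback server's accept loop.
        while True:
            time.sleep(1)

    t = threading.Thread(target=serve_forever)
    t.daemon = True    # without this, the script would hang at exit
    t.start()

    print "main thread finished; the daemon thread will not block exit"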
+ + Also, it will update the port number (0) with real port """ # TODO: create a patch for Py4J import socket @@ -54,8 +56,11 @@ def start(self): 1) try: self.server_socket.bind((self.address, self.port)) - except Exception: - msg = 'An error occurred while trying to start the callback server' + if not self.port: + # update port with real port + self.port = self.server_socket.getsockname()[1] + except Exception as e: + msg = 'An error occurred while trying to start the callback server: %s' % e logger.exception(msg) raise Py4JNetworkError(msg) @@ -105,15 +110,24 @@ def _jduration(self, seconds): def _ensure_initialized(cls): SparkContext._ensure_initialized() gw = SparkContext._gateway - # start callback server - # getattr will fallback to JVM - if "_callback_server" not in gw.__dict__: - _daemonize_callback_server() - gw._start_callback_server(gw._python_proxy_port) java_import(gw.jvm, "org.apache.spark.streaming.*") java_import(gw.jvm, "org.apache.spark.streaming.api.java.*") java_import(gw.jvm, "org.apache.spark.streaming.api.python.*") + + # start callback server + # getattr will fallback to JVM, so we cannot test by hasattr() + if "_callback_server" not in gw.__dict__: + _daemonize_callback_server() + # use random port + gw._start_callback_server(0) + # gateway with real port + gw._python_proxy_port = gw._callback_server.port + # get the GatewayServer object in JVM by ID + jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) + # update the port of CallbackClient with real port + gw.jvm.PythonDStream.updatePythonGatewayPort(jgws, gw._python_proxy_port) + # register serializer for TransformFunction # it happens before creating SparkContext when loading from checkpointing cls._transformerSerializer = TransformFunctionSerializer( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 96b84b45b2ebf..e171fb5730616 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -24,6 +24,8 @@ import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ import scala.language.existentials +import py4j.GatewayServer + import org.apache.spark.api.java._ import org.apache.spark.api.python._ import org.apache.spark.rdd.RDD @@ -88,10 +90,14 @@ private[python] class TransformFunction(@transient var pfunc: PythonTransformFun */ private[python] object PythonTransformFunctionSerializer { - // A serializer in Python, used to serialize PythonTransformFunction + /** + * A serializer in Python, used to serialize PythonTransformFunction + */ private var serializer: PythonTransformFunctionSerializer = _ - // Register a serializer from Python, should be called during initialization + /* + * Register a serializer from Python, should be called during initialization + */ def register(ser: PythonTransformFunctionSerializer): Unit = { serializer = ser } @@ -117,20 +123,36 @@ private[python] object PythonTransformFunctionSerializer { */ private[python] object PythonDStream { - // can not access PythonTransformFunctionSerializer.register() via Py4j - // Py4JError: PythonTransformFunctionSerializerregister does not exist in the JVM + /** + * can not access PythonTransformFunctionSerializer.register() via Py4j + * Py4JError: PythonTransformFunctionSerializerregister does not exist in the JVM + */ def registerSerializer(ser: 
PythonTransformFunctionSerializer): Unit = { PythonTransformFunctionSerializer.register(ser) } - // helper function for DStream.foreachRDD(), - // cannot be `foreachRDD`, it will confusing py4j + /** + * Update the port of callback client to `port` + */ + def updatePythonGatewayPort(gws: GatewayServer, port: Int): Unit = { + val cl = gws.getCallbackClient + val f = cl.getClass.getDeclaredField("port") + f.setAccessible(true) + f.setInt(cl, port) + } + + /** + * helper function for DStream.foreachRDD(), + * cannot be `foreachRDD`, it will confusing py4j + */ def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonTransformFunction) { val func = new TransformFunction((pfunc)) jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time)) } - // convert list of RDD into queue of RDDs, for ssc.queueStream() + /** + * convert list of RDD into queue of RDDs, for ssc.queueStream() + */ def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = { val queue = new java.util.LinkedList[JavaRDD[Array[Byte]]] rdds.forall(queue.add(_)) From e108ec114eb1a14c6e2387761da8e55bee4b3c83 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 2 Oct 2014 22:51:18 -0700 Subject: [PATCH 341/347] address comments --- .../apache/spark/api/python/PythonRDD.scala | 8 -- python/pyspark/rdd.py | 2 + python/pyspark/streaming/context.py | 38 +++--- python/pyspark/streaming/dstream.py | 112 +++++++++--------- python/pyspark/streaming/tests.py | 4 +- .../streaming/api/python/PythonDStream.scala | 2 +- 6 files changed, 80 insertions(+), 86 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index fd6e3406a3b7e..f36a651dc2d8f 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -25,8 +25,6 @@ import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collectio import scala.collection.JavaConversions._ import scala.collection.mutable import scala.language.existentials -import scala.reflect.ClassTag -import scala.util.{Try, Success, Failure} import net.razorvine.pickle.{Pickler, Unpickler} @@ -52,12 +50,6 @@ private[spark] class PythonRDD( accumulator: Accumulator[JList[Array[Byte]]]) extends RDD[Array[Byte]](parent) { - // create a new PythonRDD with same Python setting but different parent. 
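The random-port scheme in the hunks above rests on a standard socket idiom: binding to port 0 asks the OS for a free ephemeral port, and getsockname() reports which one was assigned, which is then pushed back into the JVM's callback client via updatePythonGatewayPort(). The idiom on its own, as a standalone sketch (host and message are invented):

    import socket

    server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

    # Port 0 means "any free port"; the kernel picks one at bind time.
    server_socket.bind(("127.0.0.1", 0))
    real_port = server_socket.getsockname()[1]
    print "callback server would listen on port %d" % real_port

    server_socket.close()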
- def copyTo(rdd: RDD[_]): PythonRDD = { - new PythonRDD(rdd, command, envVars, pythonIncludes, preservePartitoning, - pythonExec, broadcastVars, accumulator) - } - val bufferSize = conf.getInt("spark.buffer.size", 65536) val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index dc6497772e502..77e8fb1773fd1 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -787,6 +787,8 @@ def sum(self): >>> sc.parallelize([1.0, 2.0, 3.0]).sum() 6.0 """ + if not self.getNumPartitions(): + return 0 # empty RDD can not been reduced return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add) def count(self): diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 7f99d38771ce8..dc9dc41121935 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -84,17 +84,18 @@ class StreamingContext(object): """ _transformerSerializer = None - def __init__(self, sparkContext, duration=None, jssc=None): + def __init__(self, sparkContext, batchDuration=None, jssc=None): """ Create a new StreamingContext. @param sparkContext: L{SparkContext} object. - @param duration: number of seconds. + @param batchDuration: the time interval (in seconds) at which streaming + data will be divided into batches """ self._sc = sparkContext self._jvm = self._sc._jvm - self._jssc = jssc or self._initialize_context(self._sc, duration) + self._jssc = jssc or self._initialize_context(self._sc, batchDuration) def _initialize_context(self, sc, duration): self._ensure_initialized() @@ -134,26 +135,27 @@ def _ensure_initialized(cls): SparkContext._active_spark_context, CloudPickleSerializer(), gw) @classmethod - def getOrCreate(cls, path, setupFunc): + def getOrCreate(cls, checkpointPath, setupFunc): """ - Get the StreamingContext from checkpoint file at `path`, or setup - it by `setupFunc`. + Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + recreated from the checkpoint data. If the data does not exist, then the provided setupFunc + will be used to create a JavaStreamingContext. - :param path: directory of checkpoint - :param setupFunc: a function used to create StreamingContext and - setup DStreams. 
- :return: a StreamingContext + @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program + @param setupFunc Function to create a new JavaStreamingContext and setup DStreams """ - if not os.path.exists(path) or not os.path.isdir(path) or not os.listdir(path): + # TODO: support checkpoint in HDFS + if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath): ssc = setupFunc() - ssc.checkpoint(path) + ssc.checkpoint(checkpointPath) return ssc cls._ensure_initialized() gw = SparkContext._gateway try: - jssc = gw.jvm.JavaStreamingContext(path) + jssc = gw.jvm.JavaStreamingContext(checkpointPath) except Exception: print >>sys.stderr, "failed to load StreamingContext from checkpoint" raise @@ -249,12 +251,12 @@ def textFileStream(self, directory): """ return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer()) - def _check_serialzers(self, rdds): + def _check_serializers(self, rdds): # make sure they have same serializer if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1: for i in range(len(rdds)): # reset them to sc.serializer - rdds[i] = rdds[i].map(lambda x: x, preservesPartitioning=True) + rdds[i] = rdds[i]._reserialize() def queueStream(self, rdds, oneAtATime=True, default=None): """ @@ -275,7 +277,7 @@ def queueStream(self, rdds, oneAtATime=True, default=None): if rdds and not isinstance(rdds[0], RDD): rdds = [self._sc.parallelize(input) for input in rdds] - self._check_serialzers(rdds) + self._check_serializers(rdds) jrdds = ListConverter().convert([r._jrdd for r in rdds], SparkContext._gateway._gateway_client) @@ -313,6 +315,10 @@ def union(self, *dstreams): raise ValueError("should have at least one DStream to union") if len(dstreams) == 1: return dstreams[0] + if len(set(s._jrdd_deserializer for s in dstreams)) > 1: + raise ValueError("All DStreams should have same serializer") + if len(set(s._slideDuration for s in dstreams)) > 1: + raise ValueError("All DStreams should have same slide duration") first = dstreams[0] jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]], SparkContext._gateway._gateway_client) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index fddfd757b8674..824131739cce3 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -53,7 +53,7 @@ class DStream(object): def __init__(self, jdstream, ssc, jrdd_deserializer): self._jdstream = jdstream self._ssc = ssc - self.ctx = ssc._sc + self._sc = ssc._sc self._jrdd_deserializer = jrdd_deserializer self.is_cached = False self.is_checkpointed = False @@ -69,13 +69,7 @@ def count(self): Return a new DStream in which each RDD has a single element generated by counting each RDD of this DStream. """ - return self.mapPartitions(lambda i: [sum(1 for _ in i)])._sum() - - def _sum(self): - """ - Add up the elements in this DStream. - """ - return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add) + return self.mapPartitions(lambda i: [sum(1 for _ in i)]).reduce(operator.add) def filter(self, f): """ @@ -130,7 +124,7 @@ def reduceByKey(self, func, numPartitions=None): Return a new DStream by applying reduceByKey to each RDD. 
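A hedged usage sketch of the getOrCreate() contract documented above: the setup function only runs when no usable checkpoint exists under the given directory. Paths, app name and batch interval below are invented for the example.

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    checkpoint_dir = "/tmp/streaming-checkpoint"            # hypothetical path

    def create_context():
        sc = SparkContext(appName="GetOrCreateSketch")
        ssc = StreamingContext(sc, 1)
        lines = ssc.textFileStream("/tmp/streaming-input")  # hypothetical path
        lines.count().pprint()
        ssc.checkpoint(checkpoint_dir)
        return ssc

    # First run: create_context() builds the graph and starts checkpointing.
    # Restarts: the context is rebuilt from checkpoint_dir and create_context()
    # is skipped entirely.
    ssc = StreamingContext.getOrCreate(checkpoint_dir, create_context)
    ssc.start()
    ssc.awaitTermination()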
""" if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._sc.defaultParallelism return self.combineByKey(lambda x: x, func, func, numPartitions) def combineByKey(self, createCombiner, mergeValue, mergeCombiners, @@ -139,7 +133,7 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, Return a new DStream by applying combineByKey to each RDD. """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._sc.defaultParallelism def func(rdd): return rdd.combineByKey(createCombiner, mergeValue, mergeCombiners, numPartitions) @@ -156,7 +150,7 @@ def foreachRDD(self, func): """ Apply a function to each RDD in this DStream. """ - jfunc = TransformFunction(self.ctx, func, self._jrdd_deserializer) + jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) api = self._ssc._jvm.PythonDStream api.callForeachRDD(self._jdstream, jfunc) @@ -216,7 +210,7 @@ def persist(self, storageLevel): Persist the RDDs of this DStream with the given storage level """ self.is_cached = True - javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel) + javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) self._jdstream.persist(javaStorageLevel) return self @@ -236,7 +230,7 @@ def groupByKey(self, numPartitions=None): Return a new DStream by applying groupByKey on each RDD. """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._sc.defaultParallelism return self.transform(lambda rdd: rdd.groupByKey(numPartitions)) def countByValue(self): @@ -262,21 +256,22 @@ def saveAsTextFile(t, rdd): raise return self.foreachRDD(saveAsTextFile) - def _saveAsPickleFiles(self, prefix, suffix=None): - """ - Save each RDD in this DStream as at binary file, the elements are - serialized by pickle. - """ - def saveAsPickleFile(t, rdd): - path = rddToFileName(prefix, suffix, t) - try: - rdd.saveAsPickleFile(path) - except Py4JJavaError as e: - # after recovered from checkpointing, the foreachRDD may - # be called twice - if 'FileAlreadyExistsException' not in str(e): - raise - return self.foreachRDD(saveAsPickleFile) + # TODO: uncomment this until we have ssc.pickleFileStream() + # def saveAsPickleFiles(self, prefix, suffix=None): + # """ + # Save each RDD in this DStream as at binary file, the elements are + # serialized by pickle. 
+ # """ + # def saveAsPickleFile(t, rdd): + # path = rddToFileName(prefix, suffix, t) + # try: + # rdd.saveAsPickleFile(path) + # except Py4JJavaError as e: + # # after recovered from checkpointing, the foreachRDD may + # # be called twice + # if 'FileAlreadyExistsException' not in str(e): + # raise + # return self.foreachRDD(saveAsPickleFile) def transform(self, func): """ @@ -304,10 +299,10 @@ def transformWith(self, func, other, keepSerializer=False): oldfunc = func func = lambda t, a, b: oldfunc(a, b) assert func.func_code.co_argcount == 3, "func should take two or three arguments" - jfunc = TransformFunction(self.ctx, func, self._jrdd_deserializer, other._jrdd_deserializer) - dstream = self.ctx._jvm.PythonTransformed2DStream(self._jdstream.dstream(), + jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer, other._jrdd_deserializer) + dstream = self._sc._jvm.PythonTransformed2DStream(self._jdstream.dstream(), other._jdstream.dstream(), jfunc) - jrdd_serializer = self._jrdd_deserializer if keepSerializer else self.ctx.serializer + jrdd_serializer = self._jrdd_deserializer if keepSerializer else self._sc.serializer return DStream(dstream.asJavaDStream(), self._ssc, jrdd_serializer) def repartition(self, numPartitions): @@ -336,61 +331,61 @@ def union(self, other): def cogroup(self, other, numPartitions=None): """ - Return a new DStream by applying 'cogroup' between RDDs of `this` + Return a new DStream by applying 'cogroup' between RDDs of this DStream and `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions` partitions. """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._sc.defaultParallelism return self.transformWith(lambda a, b: a.cogroup(b, numPartitions), other) def join(self, other, numPartitions=None): """ - Return a new DStream by applying 'join' between RDDs of `this` DStream and + Return a new DStream by applying 'join' between RDDs of this DStream and `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions` partitions. """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._sc.defaultParallelism return self.transformWith(lambda a, b: a.join(b, numPartitions), other) def leftOuterJoin(self, other, numPartitions=None): """ - Return a new DStream by applying 'left outer join' between RDDs of `this` DStream and + Return a new DStream by applying 'left outer join' between RDDs of this DStream and `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions` partitions. """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._sc.defaultParallelism return self.transformWith(lambda a, b: a.leftOuterJoin(b, numPartitions), other) def rightOuterJoin(self, other, numPartitions=None): """ - Return a new DStream by applying 'right outer join' between RDDs of `this` DStream and + Return a new DStream by applying 'right outer join' between RDDs of this DStream and `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions` partitions. 
""" if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._sc.defaultParallelism return self.transformWith(lambda a, b: a.rightOuterJoin(b, numPartitions), other) def fullOuterJoin(self, other, numPartitions=None): """ - Return a new DStream by applying 'full outer join' between RDDs of `this` DStream and + Return a new DStream by applying 'full outer join' between RDDs of this DStream and `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions` partitions. """ if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._sc.defaultParallelism return self.transformWith(lambda a, b: a.fullOuterJoin(b, numPartitions), other) def _jtime(self, timestamp): @@ -398,7 +393,7 @@ def _jtime(self, timestamp): """ if isinstance(timestamp, datetime): timestamp = time.mktime(timestamp.timetuple()) - return self.ctx._jvm.Time(long(timestamp * 1000)) + return self._sc._jvm.Time(long(timestamp * 1000)) def slice(self, begin, end): """ @@ -407,7 +402,7 @@ def slice(self, begin, end): `begin`, `end` could be datetime.datetime() or unix_timestamp """ jrdds = self._jdstream.slice(self._jtime(begin), self._jtime(end)) - return [RDD(jrdd, self.ctx, self._jrdd_deserializer) for jrdd in jrdds] + return [RDD(jrdd, self._sc, self._jrdd_deserializer) for jrdd in jrdds] def _validate_window_param(self, window, slide): duration = self._jdstream.dstream().slideDuration().milliseconds() @@ -532,7 +527,7 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None """ self._validate_window_param(windowDuration, slideDuration) if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._sc.defaultParallelism reduced = self.reduceByKey(func, numPartitions) @@ -548,18 +543,18 @@ def invReduceFunc(t, a, b): joined = a.leftOuterJoin(b, numPartitions) return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1) - jreduceFunc = TransformFunction(self.ctx, reduceFunc, reduced._jrdd_deserializer) + jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer) if invReduceFunc: - jinvReduceFunc = TransformFunction(self.ctx, invReduceFunc, reduced._jrdd_deserializer) + jinvReduceFunc = TransformFunction(self._sc, invReduceFunc, reduced._jrdd_deserializer) else: jinvReduceFunc = None if slideDuration is None: slideDuration = self._slideDuration - dstream = self.ctx._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(), + dstream = self._sc._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(), jreduceFunc, jinvReduceFunc, self._ssc._jduration(windowDuration), self._ssc._jduration(slideDuration)) - return DStream(dstream.asJavaDStream(), self._ssc, self.ctx.serializer) + return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) def updateStateByKey(self, updateFunc, numPartitions=None): """ @@ -570,7 +565,7 @@ def updateStateByKey(self, updateFunc, numPartitions=None): If `s` is None, then `k` will be eliminated. 
""" if numPartitions is None: - numPartitions = self.ctx.defaultParallelism + numPartitions = self._sc.defaultParallelism def reduceFunc(t, a, b): if a is None: @@ -581,10 +576,10 @@ def reduceFunc(t, a, b): state = g.mapPartitions(lambda x: updateFunc(x)) return state.filter(lambda (k, v): v is not None) - jreduceFunc = TransformFunction(self.ctx, reduceFunc, - self.ctx.serializer, self._jrdd_deserializer) - dstream = self.ctx._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc) - return DStream(dstream.asJavaDStream(), self._ssc, self.ctx.serializer) + jreduceFunc = TransformFunction(self._sc, reduceFunc, + self._sc.serializer, self._jrdd_deserializer) + dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc) + return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) class TransformedDStream(DStream): @@ -596,10 +591,9 @@ class TransformedDStream(DStream): one transformation. """ def __init__(self, prev, func): - ssc = prev._ssc - self._ssc = ssc - self.ctx = ssc._sc - self._jrdd_deserializer = self.ctx.serializer + self._ssc = prev._ssc + self._sc = self._ssc._sc + self._jrdd_deserializer = self._sc.serializer self.is_cached = False self.is_checkpointed = False self._jdstream_val = None @@ -618,7 +612,7 @@ def _jdstream(self): if self._jdstream_val is not None: return self._jdstream_val - jfunc = TransformFunction(self.ctx, self.func, self.prev._jrdd_deserializer) - dstream = self.ctx._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc) + jfunc = TransformFunction(self._sc, self.func, self.prev._jrdd_deserializer) + dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc) self._jdstream_val = dstream.asJavaDStream() return self._jdstream_val diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index a839faecf9a16..9f5cdff5ed809 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -496,11 +496,11 @@ def updater(it): def setup(): conf = SparkConf().set("spark.default.parallelism", 1) sc = SparkContext(conf=conf) - ssc = StreamingContext(sc, 0.2) + ssc = StreamingContext(sc, 0.5) dstream = ssc.textFileStream(inputd).map(lambda x: (x, 1)) wc = dstream.updateStateByKey(updater) wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test") - wc.checkpoint(.2) + wc.checkpoint(.5) return ssc cpd = tempfile.mkdtemp("test_streaming_cps") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index e171fb5730616..696dfb969a48a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -36,7 +36,7 @@ import org.apache.spark.streaming.api.java._ /** - * Interface for Python callback function with three arguments + * Interface for Python callback function which is used to transform RDDs */ private[python] trait PythonTransformFunction { def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] From 52c535b0696b3861222a7bd6608bb3f6f4db64c3 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 3 Oct 2014 08:54:33 -0700 Subject: [PATCH 342/347] remove fix for sum() --- python/pyspark/rdd.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 77e8fb1773fd1..dc6497772e502 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -787,8 
+787,6 @@ def sum(self): >>> sc.parallelize([1.0, 2.0, 3.0]).sum() 6.0 """ - if not self.getNumPartitions(): - return 0 # empty RDD can not been reduced return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add) def count(self): From bebeb4aa6df42b6a72ffa7afb574891d2ce46c59 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 7 Oct 2014 16:44:57 -0700 Subject: [PATCH 343/347] address all comments --- python/pyspark/streaming/dstream.py | 8 ++++---- python/pyspark/streaming/util.py | 17 ++++++++++++++++- .../streaming/api/python/PythonDStream.scala | 11 ++++++++--- 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 824131739cce3..4533c5d541a51 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -174,7 +174,7 @@ def takeAndPrint(time, rdd): def mapValues(self, f): """ Return a new DStream by applying a map function to the value of - each key-value pairs in 'this' DStream without changing the key. + each key-value pairs in this DStream without changing the key. """ map_values_fn = lambda (k, v): (k, f(v)) return self.map(map_values_fn, preservesPartitioning=True) @@ -182,7 +182,7 @@ def mapValues(self, f): def flatMapValues(self, f): """ Return a new DStream by applying a flatmap function to the value - of each key-value pairs in 'this' DStream without changing the key. + of each key-value pairs in this DStream without changing the key. """ flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) return self.flatMap(flat_map_fn, preservesPartitioning=True) @@ -276,7 +276,7 @@ def saveAsTextFile(t, rdd): def transform(self, func): """ Return a new DStream in which each RDD is generated by applying a function - on each RDD of 'this' DStream. + on each RDD of this DStream. `func` can have one argument of `rdd`, or have two arguments of (`time`, `rdd`) @@ -290,7 +290,7 @@ def transform(self, func): def transformWith(self, func, other, keepSerializer=False): """ Return a new DStream in which each RDD is generated by applying a function - on each RDD of 'this' DStream and 'other' DStream. + on each RDD of this DStream and 'other' DStream. `func` can have two arguments of (`rdd_a`, `rdd_b`) or have three arguments of (`time`, `rdd_a`, `rdd_b`) diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index aecf7f71fdbc7..86ee5aa04f252 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -24,7 +24,12 @@ class TransformFunction(object): """ - This class is for py4j callback. + This class wraps a function RDD[X] -> RDD[Y] that was passed to + DStream.transform(), allowing it to be called from Java via Py4J's + callback server. + + Java calls this function with a sequence of JavaRDDs and this function + returns a single JavaRDD pointer back to Java. """ _emptyRDD = None @@ -63,6 +68,16 @@ class Java: class TransformFunctionSerializer(object): + """ + This class implements a serializer for PythonTransformFunction Java + objects. + + This is necessary because the Java PythonTransformFunction objects are + actually Py4J references to Python objects and thus are not directly + serializable. When Java needs to serialize a PythonTransformFunction, + it uses this class to invoke Python, which returns the serialized function + as a byte array. 
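The TransformFunction machinery described above is what receives the callback handed to DStream.transform(); on the Python side that callback may take either the RDD alone or a (time, rdd) pair. A brief usage sketch (stream contents invented):

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    sc = SparkContext(appName="TransformSketch")
    ssc = StreamingContext(sc, 1)
    nums = ssc.queueStream([sc.parallelize(range(10))])

    # One-argument form: only the batch RDD is passed in.
    evens = nums.transform(lambda rdd: rdd.filter(lambda x: x % 2 == 0))

    # Two-argument form: the batch time is passed as well.
    stamped = nums.transform(lambda t, rdd: rdd.map(lambda x: (t, x)))

    evens.pprint()
    stamped.pprint()
    ssc.start()
    ssc.awaitTermination()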
+ """ def __init__(self, ctx, serializer, gateway=None): self.ctx = ctx self.serializer = serializer diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 696dfb969a48a..213dff6a76354 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -51,10 +51,12 @@ private[python] trait PythonTransformFunctionSerializer { } /** - * Wrapper for PythonTransformFunction + * Wraps a PythonTransformFunction (which is a Python object accessed through Py4J) + * so that it looks like a Scala function and can be transparently serialized and + * deserialized by Java. */ private[python] class TransformFunction(@transient var pfunc: PythonTransformFunction) - extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] with Serializable { + extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] { def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { Option(pfunc.call(time.milliseconds, List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava)) @@ -87,6 +89,9 @@ private[python] class TransformFunction(@transient var pfunc: PythonTransformFun /** * Helpers for PythonTransformFunctionSerializer + * + * PythonTransformFunctionSerializer is logically a singleton that's happens to be + * implemented as a Python object. */ private[python] object PythonTransformFunctionSerializer { @@ -119,7 +124,7 @@ private[python] object PythonTransformFunctionSerializer { } /** - * Helper functions + * Helper functions, which are called from Python via Py4J. */ private[python] object PythonDStream { From 02d05751ea281d377ce52ad39ccd30e518d2ff5a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 10 Oct 2014 14:27:24 -0700 Subject: [PATCH 344/347] add wrapper for foreachRDD() --- python/pyspark/streaming/dstream.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 4533c5d541a51..5d0dface2f043 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -150,6 +150,9 @@ def foreachRDD(self, func): """ Apply a function to each RDD in this DStream. """ + if func.func_code.co_argcount == 1: + old_func = func + func = lambda t, rdd: old_func(rdd) jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) api = self._ssc._jvm.PythonDStream api.callForeachRDD(self._jdstream, jfunc) From 3e2492b9b95e0cc0e3427265f71f069000cc43f7 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 10 Oct 2014 15:02:29 -0700 Subject: [PATCH 345/347] change updateStateByKey() to easy API --- .../streaming/stateful_network_wordcount.py | 57 +++++++++++++++++++ python/pyspark/streaming/dstream.py | 10 ++-- python/pyspark/streaming/tests.py | 22 ++++--- 3 files changed, 72 insertions(+), 17 deletions(-) create mode 100644 examples/src/main/python/streaming/stateful_network_wordcount.py diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py new file mode 100644 index 0000000000000..7bd1512180920 --- /dev/null +++ b/examples/src/main/python/streaming/stateful_network_wordcount.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text received from the + network every second. + + Usage: stateful_network_wordcount.py + and describe the TCP server that Spark Streaming + would connect to receive data. + + To run this on your local machine, you need to first run a Netcat server + `$ nc -lk 9999` + and then run the example + `$ bin/spark-submit examples/src/main/python/streaming/stateful_network_wordcount.py \ + localhost 9999` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: stateful_network_wordcount.py " + exit(-1) + sc = SparkContext(appName="PythonStreamingNetworkWordCount") + ssc = StreamingContext(sc, 1) + ssc.checkpoint("checkpoint") + + def updateFunc(new_values, last_sum): + return sum(new_values) + (last_sum or 0) + + lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2])) + running_counts = lines.flatMap(lambda line: line.split(" "))\ + .map(lambda word: (word, 1))\ + .updateStateByKey(updateFunc) + + running_counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 5d0dface2f043..5ae5cf07f0137 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -564,19 +564,19 @@ def updateStateByKey(self, updateFunc, numPartitions=None): Return a new "state" DStream where the state for each key is updated by applying the given function on the previous state of the key and the new values of the key. - @param updateFunc: State update function ([(k, vs, s)] -> [(k, s)]). - If `s` is None, then `k` will be eliminated. + @param updateFunc: State update function. If this function returns None, then + corresponding state key-value pair will be eliminated. 
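The None behaviour in the sentence above lends itself to a concrete illustration: with the per-key (new_values, state) signature, expiring idle keys is just a matter of returning None from the updater. The stream and expiry rule below are invented for the sketch.

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    sc = SparkContext(appName="ExpiringStateSketch")
    ssc = StreamingContext(sc, 1)
    ssc.checkpoint("checkpoint")          # updateStateByKey needs checkpointing

    events = ssc.queueStream([sc.parallelize([("a", 1), ("b", 1)]),
                              sc.parallelize([("a", 1)])])

    def updater(new_values, count):
        if not new_values:                # key saw no activity this batch
            return None                   # returning None drops it from the state
        return (count or 0) + sum(new_values)

    events.updateStateByKey(updater).pprint()
    ssc.start()
    ssc.awaitTermination()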
""" if numPartitions is None: numPartitions = self._sc.defaultParallelism def reduceFunc(t, a, b): if a is None: - g = b.groupByKey(numPartitions).map(lambda (k, vs): (k, list(vs), None)) + g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None)) else: g = a.cogroup(b, numPartitions) - g = g.map(lambda (k, (va, vb)): (k, list(vb), list(va)[0] if len(va) else None)) - state = g.mapPartitions(lambda x: updateFunc(x)) + g = g.mapValues(lambda (va, vb): (list(vb), list(va)[0] if len(va) else None)) + state = g.mapValues(lambda (vs, s): updateFunc(vs, s)) return state.filter(lambda (k, v): v is not None) jreduceFunc = TransformFunction(self._sc, reduceFunc, diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 9f5cdff5ed809..0e5c1a3b3c2ad 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -119,7 +119,7 @@ def _sort_result_based_on_key(self, outputs): output.sort(key=lambda x: x[0]) -class TestBasicOperations(PySparkStreamingTestCase): +class BasicOperationTests(PySparkStreamingTestCase): def test_map(self): """Basic operation test for DStream.map.""" @@ -340,15 +340,13 @@ def func(a, b): expected = [[('a', (1, None)), ('b', (2, 3)), ('c', (None, 4))]] self._test_func(input, func, expected, True, input2) - def update_state_by_key(self): + def test_update_state_by_key(self): - def updater(it): - for k, vs, s in it: - if not s: - s = vs - else: - s.extend(vs) - yield (k, s) + def updater(vs, s): + if not s: + s = [] + s.extend(vs) + return s input = [[('k', i)] for i in range(5)] @@ -360,7 +358,7 @@ def func(dstream): self._test_func(input, func, expected) -class TestWindowFunctions(PySparkStreamingTestCase): +class WindowFunctionTests(PySparkStreamingTestCase): timeout = 20 @@ -417,7 +415,7 @@ def test_reduce_by_invalid_window(self): self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1)) -class TestStreamingContext(PySparkStreamingTestCase): +class StreamingContextTests(PySparkStreamingTestCase): duration = 0.1 @@ -480,7 +478,7 @@ def func(rdds): self.assertEqual([2, 3, 1], self._take(dstream, 3)) -class TestCheckpoint(PySparkStreamingTestCase): +class CheckpointTests(PySparkStreamingTestCase): def setUp(self): pass From 331ecced6f61ad5183da5830f94f584bcc74e479 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 10 Oct 2014 22:25:09 -0700 Subject: [PATCH 346/347] fix example --- .../src/main/python/streaming/stateful_network_wordcount.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py index 7bd1512180920..18a9a5a452ffb 100644 --- a/examples/src/main/python/streaming/stateful_network_wordcount.py +++ b/examples/src/main/python/streaming/stateful_network_wordcount.py @@ -39,7 +39,7 @@ if len(sys.argv) != 3: print >> sys.stderr, "Usage: stateful_network_wordcount.py " exit(-1) - sc = SparkContext(appName="PythonStreamingNetworkWordCount") + sc = SparkContext(appName="PythonStreamingStatefulNetworkWordCount") ssc = StreamingContext(sc, 1) ssc.checkpoint("checkpoint") From 64561e4e503eafb958f6769383ba3b37edbe5fa2 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 10 Oct 2014 22:47:46 -0700 Subject: [PATCH 347/347] fix tests --- python/pyspark/streaming/tests.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 0e5c1a3b3c2ad..a8d876d0fa3b3 
100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -487,9 +487,8 @@ def test_get_or_create(self): inputd = tempfile.mkdtemp() outputd = tempfile.mkdtemp() + "/" - def updater(it): - for k, vs, s in it: - yield (k, sum(vs, s or 0)) + def updater(vs, s): + return sum(vs, s or 0) def setup(): conf = SparkConf().set("spark.default.parallelism", 1)
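As a closing illustration of the data flow behind the final updateStateByKey() implementation, the same merge can be written against plain dictionaries: new values are grouped per key, combined with the previous state, and keys whose updater returns None are dropped. Everything below is a standalone sketch, not code from the patch.

    def merge_state(prev_state, new_batch, update_func):
        # prev_state: {key: state}; new_batch: list of (key, value) pairs.
        grouped = {}
        for k, v in new_batch:
            grouped.setdefault(k, []).append(v)
        out = {}
        for k in set(prev_state) | set(grouped):
            s = update_func(grouped.get(k, []), prev_state.get(k))
            if s is not None:             # None drops the key, as in the DStream code
                out[k] = s
        return out

    counts = merge_state({"a": 2}, [("a", 1), ("b", 1)],
                         lambda vs, s: sum(vs, s or 0))
    print counts    # {'a': 3, 'b': 1}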