diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 6ef8cf53cc747..ea293287a8314 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -15,77 +15,6 @@ # limitations under the License. # -""" ->>> from pyspark.context import SparkContext ->>> sc = SparkContext('local', 'test') ->>> a = sc.accumulator(1) ->>> a.value -1 ->>> a.value = 2 ->>> a.value -2 ->>> a += 5 ->>> a.value -7 - ->>> sc.accumulator(1.0).value -1.0 - ->>> sc.accumulator(1j).value -1j - ->>> rdd = sc.parallelize([1,2,3]) ->>> def f(x): -... global a -... a += x ->>> rdd.foreach(f) ->>> a.value -13 - ->>> b = sc.accumulator(0) ->>> def g(x): -... b.add(x) ->>> rdd.foreach(g) ->>> b.value -6 - ->>> from pyspark.accumulators import AccumulatorParam ->>> class VectorAccumulatorParam(AccumulatorParam): -... def zero(self, value): -... return [0.0] * len(value) -... def addInPlace(self, val1, val2): -... for i in range(len(val1)): -... val1[i] += val2[i] -... return val1 ->>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam()) ->>> va.value -[1.0, 2.0, 3.0] ->>> def g(x): -... global va -... va += [x] * 3 ->>> rdd.foreach(g) ->>> va.value -[7.0, 8.0, 9.0] - ->>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL -Traceback (most recent call last): - ... -Py4JJavaError:... - ->>> def h(x): -... global a -... a.value = 7 ->>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL -Traceback (most recent call last): - ... -Py4JJavaError:... - ->>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL -Traceback (most recent call last): - ... -TypeError:... -""" - import sys import select import struct @@ -117,6 +46,76 @@ def _deserialize_accumulator(aid, zero_value, accum_param): class Accumulator(object): + """ + >>> from pyspark.context import SparkContext + >>> sc = SparkContext('local', 'test') + >>> a = sc.accumulator(1) + >>> a.value + 1 + >>> a.value = 2 + >>> a.value + 2 + >>> a += 5 + >>> a.value + 7 + + >>> sc.accumulator(1.0).value + 1.0 + + >>> sc.accumulator(1j).value + 1j + + >>> rdd = sc.parallelize([1,2,3]) + >>> def f(x): + ... global a + ... a += x + >>> rdd.foreach(f) + >>> a.value + 13 + + >>> b = sc.accumulator(0) + >>> def g(x): + ... b.add(x) + >>> rdd.foreach(g) + >>> b.value + 6 + + >>> from pyspark.accumulators import AccumulatorParam + >>> class VectorAccumulatorParam(AccumulatorParam): + ... def zero(self, value): + ... return [0.0] * len(value) + ... def addInPlace(self, val1, val2): + ... for i in range(len(val1)): + ... val1[i] += val2[i] + ... return val1 + >>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam()) + >>> va.value + [1.0, 2.0, 3.0] + >>> def g(x): + ... global va + ... va += [x] * 3 + >>> rdd.foreach(g) + >>> va.value + [7.0, 8.0, 9.0] + + >>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + Py4JJavaError:... + + >>> def h(x): + ... global a + ... a.value = 7 + >>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + Py4JJavaError:... + + >>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + TypeError:... 
+ """ """ A shared variable that can be accumulated, i.e., has a commutative and associative "add" @@ -263,7 +262,7 @@ def _start_update_server(): return server if __name__ == "__main__": - import doctest - (failure_count, test_count) = doctest.testmod() - if failure_count: + from pyspark.doctesthelper import run_doctests + result = run_doctests(__file__) + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 663c9abe0881e..ab30a3509e4da 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -114,7 +114,7 @@ def __reduce__(self): if __name__ == "__main__": - import doctest - (failure_count, test_count) = doctest.testmod() - if failure_count: + from pyspark.doctesthelper import run_doctests + result = run_doctests(__file__) + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index 924da3eecf214..067e00b4bdb65 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -15,45 +15,6 @@ # limitations under the License. # -""" ->>> from pyspark.conf import SparkConf ->>> from pyspark.context import SparkContext ->>> conf = SparkConf() ->>> conf.setMaster("local").setAppName("My app") - ->>> conf.get("spark.master") -u'local' ->>> conf.get("spark.app.name") -u'My app' ->>> sc = SparkContext(conf=conf) ->>> sc.master -u'local' ->>> sc.appName -u'My app' ->>> sc.sparkHome is None -True - ->>> conf = SparkConf(loadDefaults=False) ->>> conf.setSparkHome("/path") - ->>> conf.get("spark.home") -u'/path' ->>> conf.setExecutorEnv("VAR1", "value1") - ->>> conf.setExecutorEnv(pairs = [("VAR3", "value3"), ("VAR4", "value4")]) - ->>> conf.get("spark.executorEnv.VAR1") -u'value1' ->>> print(conf.toDebugString()) -spark.executorEnv.VAR1=value1 -spark.executorEnv.VAR3=value3 -spark.executorEnv.VAR4=value4 -spark.home=/path ->>> sorted(conf.getAll(), key=lambda p: p[0]) -[(u'spark.executorEnv.VAR1', u'value1'), (u'spark.executorEnv.VAR3', u'value3'), \ -(u'spark.executorEnv.VAR4', u'value4'), (u'spark.home', u'/path')] -""" - __all__ = ['SparkConf'] import sys @@ -86,6 +47,45 @@ class SparkConf(object): and can no longer be modified by the user. """ + """ + >>> from pyspark.conf import SparkConf + >>> from pyspark.context import SparkContext + >>> conf = SparkConf() + >>> conf.setMaster("local").setAppName("My app") + + >>> conf.get("spark.master") + u'local' + >>> conf.get("spark.app.name") + u'My app' + >>> sc = SparkContext(conf=conf) + >>> sc.master + u'local' + >>> sc.appName + u'My app' + >>> sc.sparkHome is None + True + + >>> conf = SparkConf(loadDefaults=False) + >>> conf.setSparkHome("/path") + + >>> conf.get("spark.home") + u'/path' + >>> conf.setExecutorEnv("VAR1", "value1") + + >>> conf.setExecutorEnv(pairs = [("VAR3", "value3"), ("VAR4", "value4")]) + + >>> conf.get("spark.executorEnv.VAR1") + u'value1' + >>> print(conf.toDebugString()) + spark.executorEnv.VAR1=value1 + spark.executorEnv.VAR3=value3 + spark.executorEnv.VAR4=value4 + spark.home=/path + >>> sorted(conf.getAll(), key=lambda p: p[0]) + [(u'spark.executorEnv.VAR1', u'value1'), (u'spark.executorEnv.VAR3', u'value3'), \ + (u'spark.executorEnv.VAR4', u'value4'), (u'spark.home', u'/path')] + """ + def __init__(self, loadDefaults=True, _jvm=None, _jconf=None): """ Create a new Spark configuration. 
@@ -182,8 +182,9 @@ def toDebugString(self): def _test(): import doctest - (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) - if failure_count: + from pyspark.doctesthelper import run_doctests + result = run_doctests(__file__, optionflags=doctest.ELLIPSIS) + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 529d16b480399..4a378c9e85b94 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -951,15 +951,17 @@ def dump_profiles(self, path): def _test(): import atexit + from pyspark.doctesthelper import run_doctests import doctest import tempfile globs = globals().copy() globs['sc'] = SparkContext('local[4]', 'PythonTest') globs['tempdir'] = tempfile.mkdtemp() atexit.register(lambda: shutil.rmtree(globs['tempdir'])) - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS) globs['sc'].stop() - if failure_count: + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/doctesthelper.py b/python/pyspark/doctesthelper.py new file mode 100644 index 0000000000000..8947d617da43f --- /dev/null +++ b/python/pyspark/doctesthelper.py @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import sys +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest +import doctest +try: + import xmlrunner +except ImportError: + xmlrunner = None + + +def run_doctests(file_name, globs={}, optionflags=0): + t = doctest.DocFileSuite(file_name, module_relative=False, + globs=globs, optionflags=optionflags) + if xmlrunner: + return xmlrunner.XMLTestRunner(output='target/test-reports', + verbosity=3).run(t) + else: + return unittest.TextTestRunner(verbosity=3).run(t) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 16ad76483de63..94228b2abd0e6 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -16,6 +16,7 @@ # import warnings +import sys from pyspark import since from pyspark.ml.util import * @@ -878,8 +879,10 @@ def weights(self): if __name__ == "__main__": import doctest import pyspark.ml.classification + from pyspark.doctesthelper import run_doctests from pyspark.context import SparkContext from pyspark.sql import SQLContext + import tempfile globs = pyspark.ml.classification.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: @@ -887,11 +890,11 @@ def weights(self): sqlContext = SQLContext(sc) globs['sc'] = sc globs['sqlContext'] = sqlContext - import tempfile temp_path = tempfile.mkdtemp() globs['temp_path'] = temp_path try: - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS) sc.stop() finally: from shutil import rmtree @@ -899,5 +902,5 @@ def weights(self): rmtree(temp_path) except OSError: pass - if failure_count: + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 1cea477acb47d..5ca8c8461ce61 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -15,6 +15,7 @@ # limitations under the License. 
# + from pyspark import since from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel @@ -295,8 +296,10 @@ def _create_model(self, java_model): if __name__ == "__main__": import doctest import pyspark.ml.clustering + from pyspark.doctesthelper import run_doctests from pyspark.context import SparkContext from pyspark.sql import SQLContext + import tempfile globs = pyspark.ml.clustering.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: @@ -304,11 +307,11 @@ def _create_model(self, java_model): sqlContext = SQLContext(sc) globs['sc'] = sc globs['sqlContext'] = sqlContext - import tempfile temp_path = tempfile.mkdtemp() globs['temp_path'] = temp_path try: - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS) sc.stop() finally: from shutil import rmtree @@ -316,5 +319,5 @@ def _create_model(self, java_model): rmtree(temp_path) except OSError: pass - if failure_count: + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index c9b95b3bf45d9..bcbbbf12ac512 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -306,6 +306,7 @@ def setParams(self, predictionCol="prediction", labelCol="label", if __name__ == "__main__": import doctest + from pyspark.doctesthelper import run_doctests from pyspark.context import SparkContext from pyspark.sql import SQLContext globs = globals().copy() @@ -315,8 +316,8 @@ def setParams(self, predictionCol="prediction", labelCol="label", sqlContext = SQLContext(sc) globs['sc'] = sc globs['sqlContext'] = sqlContext - (failure_count, test_count) = doctest.testmod( - globs=globs, optionflags=doctest.ELLIPSIS) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS) sc.stop() - if failure_count: + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 5025493c42c38..ccc3eedb1e0f9 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2560,6 +2560,7 @@ def selectedFeatures(self): import tempfile import pyspark.ml.feature + from pyspark.doctesthelper import run_doctests from pyspark.context import SparkContext from pyspark.sql import Row, SQLContext @@ -2580,7 +2581,8 @@ def selectedFeatures(self): temp_path = tempfile.mkdtemp() globs['temp_path'] = temp_path try: - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS) sc.stop() finally: from shutil import rmtree @@ -2588,5 +2590,5 @@ def selectedFeatures(self): rmtree(temp_path) except OSError: pass - if failure_count: + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index 2b605e5c5078b..6296c12830079 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -15,6 +15,7 @@ # limitations under the License. 
# + from pyspark import since from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel @@ -324,8 +325,10 @@ def itemFactors(self): if __name__ == "__main__": import doctest import pyspark.ml.recommendation + from pyspark.doctesthelper import run_doctests from pyspark.context import SparkContext from pyspark.sql import SQLContext + import tempfile globs = pyspark.ml.recommendation.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: @@ -333,11 +336,11 @@ def itemFactors(self): sqlContext = SQLContext(sc) globs['sc'] = sc globs['sqlContext'] = sqlContext - import tempfile temp_path = tempfile.mkdtemp() globs['temp_path'] = temp_path try: - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS) sc.stop() finally: from shutil import rmtree @@ -345,5 +348,5 @@ def itemFactors(self): rmtree(temp_path) except OSError: pass - if failure_count: + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 6e23393f9102f..d01e18120f912 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -904,8 +904,10 @@ def predict(self, features): if __name__ == "__main__": import doctest import pyspark.ml.regression + from pyspark.doctesthelper import run_doctests from pyspark.context import SparkContext from pyspark.sql import SQLContext + import tempfile globs = pyspark.ml.regression.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: @@ -913,11 +915,11 @@ def predict(self, features): sqlContext = SQLContext(sc) globs['sc'] = sc globs['sqlContext'] = sqlContext - import tempfile temp_path = tempfile.mkdtemp() globs['temp_path'] = temp_path try: - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS) sc.stop() finally: from shutil import rmtree @@ -925,5 +927,5 @@ def predict(self, features): rmtree(temp_path) except OSError: pass - if failure_count: + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 77af0094dfca4..d3aa1ac522583 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -481,6 +481,7 @@ def copy(self, extra=None): if __name__ == "__main__": import doctest + from pyspark.doctesthelper import run_doctests from pyspark.context import SparkContext from pyspark.sql import SQLContext globs = globals().copy() @@ -490,8 +491,8 @@ def copy(self, extra=None): sqlContext = SQLContext(sc) globs['sc'] = sc globs['sqlContext'] = sqlContext - (failure_count, test_count) = doctest.testmod( - globs=globs, optionflags=doctest.ELLIPSIS) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS) sc.stop() - if failure_count: + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 57106f8690a7d..2f6431f9f0707 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -749,13 +749,15 @@ def update(rdd): def _test(): import doctest + from pyspark.doctesthelper import run_doctests from pyspark import SparkContext import pyspark.mllib.classification globs = pyspark.mllib.classification.__dict__.copy() globs['sc'] 
= SparkContext('local[4]', 'PythonTest', batchSize=2) - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS) globs['sc'].stop() - if failure_count: + if not result.wasSuccessful(): exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 23d118bd40900..c70ad517c57b5 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -1060,12 +1060,14 @@ def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0, def _test(): import doctest + from pyspark.doctesthelper import run_doctests import pyspark.mllib.clustering globs = pyspark.mllib.clustering.__dict__.copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS) globs['sc'].stop() - if failure_count: + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 22e68ea5b4511..6a5be9d23da13 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -15,6 +15,7 @@ # limitations under the License. # + from pyspark import since from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc from pyspark.sql import SQLContext @@ -516,13 +517,15 @@ def accuracy(self): def _test(): import doctest + from pyspark.doctesthelper import run_doctests from pyspark import SparkContext import pyspark.mllib.evaluation globs = pyspark.mllib.evaluation.__dict__.copy() globs['sc'] = SparkContext('local[4]', 'PythonTest') - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS) globs['sc'].stop() - if failure_count: + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 612935352575f..42fd1ec5d5f08 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -712,12 +712,14 @@ def transform(self, vector): def _test(): import doctest + from pyspark.doctesthelper import run_doctests from pyspark import SparkContext globs = globals().copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS) globs['sc'].stop() - if failure_count: + if not result.wasSuccessful(): exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index f339e50891166..58e1b69b825cc 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -183,15 +183,17 @@ class FreqSequence(namedtuple("FreqSequence", ["sequence", "freq"])): def _test(): import doctest + from pyspark.doctesthelper import run_doctests import pyspark.mllib.fpm + import tempfile globs = pyspark.mllib.fpm.__dict__.copy() globs['sc'] = SparkContext('local[4]', 'PythonTest') - import tempfile - temp_path = tempfile.mkdtemp() globs['temp_path'] = temp_path + try: - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS) globs['sc'].stop() finally: from shutil import 
rmtree @@ -199,7 +201,7 @@ def _test(): rmtree(temp_path) except OSError: pass - if failure_count: + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index abf00a4737948..e88910aac430d 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -1236,9 +1236,9 @@ def sparse(numRows, numCols, colPtrs, rowIndices, values): def _test(): - import doctest - (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) - if failure_count: + from pyspark.doctesthelper import run_doctests + result = run_doctests(__file__, optionflags=doctest.ELLIPSIS) + if not result.wasSuccessful(): exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index 43cb0beef1bd3..abbb32d486c0b 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -918,6 +918,7 @@ def toCoordinateMatrix(self): def _test(): import doctest + from pyspark.doctesthelper import run_doctests from pyspark import SparkContext from pyspark.sql import SQLContext from pyspark.mllib.linalg import Matrices @@ -926,9 +927,10 @@ def _test(): globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) globs['sqlContext'] = SQLContext(globs['sc']) globs['Matrices'] = Matrices - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS) globs['sc'].stop() - if failure_count: + if not result.wasSuccessful(): exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 6a3c643b66417..efaca9e6e3873 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -409,14 +409,16 @@ def gammaVectorRDD(sc, shape, scale, numRows, numCols, numPartitions=None, seed= def _test(): import doctest + from pyspark.doctesthelper import run_doctests from pyspark.context import SparkContext globs = globals().copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS) globs['sc'].stop() - if failure_count: + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 7e60255d43ead..5ca1b5d2978b3 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -317,15 +317,17 @@ def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alp def _test(): import doctest + from pyspark.doctesthelper import run_doctests import pyspark.mllib.recommendation from pyspark.sql import SQLContext globs = pyspark.mllib.recommendation.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc globs['sqlContext'] = SQLContext(sc) - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS) globs['sc'].stop() - if failure_count: + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 3b77a6200054f..d7f4563368199 
100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -806,13 +806,15 @@ def update(rdd): def _test(): import doctest + from pyspark.doctesthelper import run_doctests from pyspark import SparkContext import pyspark.mllib.regression globs = pyspark.mllib.regression.__dict__.copy() globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS) globs['sc'].stop() - if failure_count: + if not result.wasSuccessful(): exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/stat/KernelDensity.py b/python/pyspark/mllib/stat/KernelDensity.py index 7da921976d4d2..4564c5c9ed744 100644 --- a/python/pyspark/mllib/stat/KernelDensity.py +++ b/python/pyspark/mllib/stat/KernelDensity.py @@ -37,7 +37,7 @@ class KernelDensity(object): >>> sample = sc.parallelize([0.0, 1.0]) >>> kd.setSample(sample) >>> kd.estimate([0.0, 1.0]) - array([ 0.12938758, 0.12938758]) + array([ 0.3204565, 0.3204565]) """ def __init__(self): self._bandwidth = 1.0 @@ -59,3 +59,20 @@ def estimate(self, points): densities = callMLlibFunc( "estimateKernelDensity", self._sample, self._bandwidth, points) return np.asarray(densities) + + +def _test(): + import doctest + from pyspark.doctesthelper import run_doctests + from pyspark import SparkContext + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if not result.wasSuccessful(): + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 36c8f48a4a882..69ff179dae240 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -306,12 +306,14 @@ def kolmogorovSmirnovTest(data, distName="norm", *params): def _test(): import doctest + from pyspark.doctesthelper import run_doctests from pyspark import SparkContext globs = globals().copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS) globs['sc'].stop() - if failure_count: + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/mllib/stat/distribution.py b/python/pyspark/mllib/stat/distribution.py index 46f7a1d2f277a..d642badb22b27 100644 --- a/python/pyspark/mllib/stat/distribution.py +++ b/python/pyspark/mllib/stat/distribution.py @@ -30,3 +30,20 @@ class MultivariateGaussian(namedtuple('MultivariateGaussian', ['mu', 'sigma'])): >>> (m[0], m[1]) (DenseVector([11.0, 12.0]), array([[ 1., 5.],[ 3., 2.]])) """ + + +def _test(): + import doctest + from pyspark.doctesthelper import run_doctests + from pyspark import SparkContext + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if not result.wasSuccessful(): + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index f7ea466b43291..ce47c22afdb17 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -656,11 +656,13 @@ def trainRegressor(cls, data, 
categoricalFeaturesInfo, def _test(): import doctest + from pyspark.doctesthelper import run_doctests globs = globals().copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS) globs['sc'].stop() - if failure_count: + if not result.wasSuccessful(): exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 39bc6586dd582..cc26421d8a10a 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -347,14 +347,16 @@ def generateLinearRDD(sc, nexamples, nfeatures, eps, def _test(): import doctest + from pyspark.doctesthelper import run_doctests from pyspark.context import SparkContext globs = globals().copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS) globs['sc'].stop() - if failure_count: + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py index 44d17bd629473..7ad929715089a 100644 --- a/python/pyspark/profiler.py +++ b/python/pyspark/profiler.py @@ -170,7 +170,7 @@ def stats(self): if __name__ == "__main__": - import doctest - (failure_count, test_count) = doctest.testmod() - if failure_count: + from pyspark.doctesthelper import run_doctests + result = run_doctests(__file__) + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 37574cea0b687..bb5185a950327 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2004,20 +2004,20 @@ def keyBy(self, f): def repartition(self, numPartitions): """ - Return a new RDD that has exactly numPartitions partitions. + Return a new RDD that has exactly numPartitions partitions. - Can increase or decrease the level of parallelism in this RDD. - Internally, this uses a shuffle to redistribute data. - If you are decreasing the number of partitions in this RDD, consider - using `coalesce`, which can avoid performing a shuffle. + Can increase or decrease the level of parallelism in this RDD. + Internally, this uses a shuffle to redistribute data. + If you are decreasing the number of partitions in this RDD, consider + using `coalesce`, which can avoid performing a shuffle. 
- >>> rdd = sc.parallelize([1,2,3,4,5,6,7], 4) - >>> sorted(rdd.glom().collect()) - [[1], [2, 3], [4, 5], [6, 7]] - >>> len(rdd.repartition(2).glom().collect()) - 2 - >>> len(rdd.repartition(10).glom().collect()) - 10 + >>> rdd = sc.parallelize([1,2,3,4,5,6,7], 4) + >>> sorted(rdd.glom().collect()) + [[1], [2, 3], [4, 5], [6, 7]] + >>> len(rdd.repartition(2).glom().collect()) + 2 + >>> len(rdd.repartition(10).glom().collect()) + 10 """ jrdd = self._jrdd.repartition(numPartitions) return RDD(jrdd, self.ctx, self._jrdd_deserializer) @@ -2420,16 +2420,17 @@ def _is_pipelinable(self): def _test(): + from pyspark.doctesthelper import run_doctests import doctest from pyspark.context import SparkContext globs = globals().copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: globs['sc'] = SparkContext('local[4]', 'PythonTest') - (failure_count, test_count) = doctest.testmod( - globs=globs, optionflags=doctest.ELLIPSIS) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS) globs['sc'].stop() - if failure_count: + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 2a1326947f4f5..e79faf48f0ce5 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -15,40 +15,6 @@ # limitations under the License. # -""" -PySpark supports custom serializers for transferring data; this can improve -performance. - -By default, PySpark uses L{PickleSerializer} to serialize objects using Python's -C{cPickle} serializer, which can serialize nearly any Python object. -Other serializers, like L{MarshalSerializer}, support fewer datatypes but can be -faster. - -The serializer is chosen when creating L{SparkContext}: - ->>> from pyspark.context import SparkContext ->>> from pyspark.serializers import MarshalSerializer ->>> sc = SparkContext('local', 'test', serializer=MarshalSerializer()) ->>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10) -[0, 2, 4, 6, 8, 10, 12, 14, 16, 18] ->>> sc.stop() - -PySpark serialize objects in batches; By default, the batch size is chosen based -on the size of objects, also configurable by SparkContext's C{batchSize} parameter: - ->>> sc = SparkContext('local', 'test', batchSize=2) ->>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) - -Behind the scenes, this creates a JavaRDD with four partitions, each of -which contains two batches of two objects: - ->>> rdd.glom().collect() -[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] ->>> int(rdd._jrdd.count()) -8 ->>> sc.stop() -""" - import sys from itertools import chain, product import marshal @@ -438,6 +404,40 @@ class MarshalSerializer(FramedSerializer): This serializer is faster than PickleSerializer but supports fewer datatypes. """ + """ + PySpark supports custom serializers for transferring data; this can improve + performance. + + By default, PySpark uses L{PickleSerializer} to serialize objects using Python's + C{cPickle} serializer, which can serialize nearly any Python object. + Other serializers, like L{MarshalSerializer}, support fewer datatypes but can be + faster. 
+ + The serializer is chosen when creating L{SparkContext}: + + >>> from pyspark.context import SparkContext + >>> from pyspark.serializers import MarshalSerializer + >>> sc = SparkContext('local', 'test', serializer=MarshalSerializer()) + >>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10) + [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] + >>> sc.stop() + + PySpark serialize objects in batches; By default, the batch size is chosen based + on the size of objects, also configurable by SparkContext's C{batchSize} parameter: + + >>> sc = SparkContext('local', 'test', batchSize=2) + >>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) + + Behind the scenes, this creates a JavaRDD with four partitions, each of + which contains two batches of two objects: + + >>> rdd.glom().collect() + [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] + >>> int(rdd._jrdd.count()) + 8 + >>> sc.stop() + """ + def dumps(self, obj): return marshal.dumps(obj) @@ -556,7 +556,7 @@ def write_with_length(obj, stream): if __name__ == '__main__': - import doctest - (failure_count, test_count) = doctest.testmod() - if failure_count: + from pyspark.doctesthelper import run_doctests + result = run_doctests(__file__) + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index e974cda9fc3e1..93c33f759e08f 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -807,7 +807,7 @@ def load_partition(j): if __name__ == "__main__": - import doctest - (failure_count, test_count) = doctest.testmod() - if failure_count: + from pyspark.doctesthelper import run_doctests + result = run_doctests(__file__) + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 19ec6fcc5d6dc..92dbe0ea6913e 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -436,6 +436,7 @@ def __repr__(self): def _test(): import doctest + from pyspark.doctesthelper import run_doctests from pyspark.context import SparkContext from pyspark.sql import SQLContext import pyspark.sql.column @@ -447,11 +448,11 @@ def _test(): .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) - (failure_count, test_count) = doctest.testmod( - pyspark.sql.column, globs=globs, - optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | + doctest.REPORT_NDIFF) globs['sc'].stop() - if failure_count: + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 9c2f6a3c5660f..3a4864eb9bc5e 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -645,6 +645,7 @@ def register(self, name, f, returnType=StringType()): def _test(): import os import doctest + from pyspark.doctesthelper import run_doctests from pyspark.context import SparkContext from pyspark.sql import Row, SQLContext import pyspark.sql.context @@ -670,11 +671,10 @@ def _test(): ] globs['jsonStrings'] = jsonStrings globs['json'] = sc.parallelize(jsonStrings) - (failure_count, test_count) = doctest.testmod( - pyspark.sql.context, globs=globs, - optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) globs['sc'].stop() - if failure_count: + if not result.wasSuccessful(): exit(-1) diff 
--git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 7e1854c43be3b..32580dde5638e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1449,6 +1449,7 @@ def sampleBy(self, col, fractions, seed=None): def _test(): import doctest + from pyspark.doctesthelper import run_doctests from pyspark.context import SparkContext from pyspark.sql import Row, SQLContext import pyspark.sql.dataframe @@ -1467,11 +1468,11 @@ def _test(): Row(name='Tom', age=None, height=None), Row(name=None, age=None, height=None)]).toDF() - (failure_count, test_count) = doctest.testmod( - pyspark.sql.dataframe, globs=globs, - optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | + doctest.REPORT_NDIFF) globs['sc'].stop() - if failure_count: + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index dee3d536be432..d3b08ea8df87f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1685,6 +1685,7 @@ def udf(f, returnType=StringType()): def _test(): import doctest + from pyspark.doctesthelper import run_doctests from pyspark.context import SparkContext from pyspark.sql import Row, SQLContext import pyspark.sql.functions @@ -1693,11 +1694,10 @@ def _test(): globs['sc'] = sc globs['sqlContext'] = SQLContext(sc) globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF() - (failure_count, test_count) = doctest.testmod( - pyspark.sql.functions, globs=globs, - optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) globs['sc'].stop() - if failure_count: + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index ee734cb439287..ad6a1fce4375a 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -15,6 +15,7 @@ # limitations under the License. 
# + from pyspark import since from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal @@ -195,6 +196,7 @@ def pivot(self, pivot_col, values=None): def _test(): import doctest + from pyspark.doctesthelper import run_doctests from pyspark.context import SparkContext from pyspark.sql import Row, SQLContext import pyspark.sql.group @@ -213,11 +215,11 @@ def _test(): Row(course="dotNET", year=2013, earnings=48000), Row(course="Java", year=2013, earnings=30000)]).toDF() - (failure_count, test_count) = doctest.testmod( - pyspark.sql.group, globs=globs, - optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | + doctest.REPORT_NDIFF) globs['sc'].stop() - if failure_count: + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 438662bb157f0..213cb24afa331 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -599,6 +599,7 @@ def _test(): import doctest import os import tempfile + from pyspark.doctesthelper import run_doctests from pyspark.context import SparkContext from pyspark.sql import Row, SQLContext, HiveContext import pyspark.sql.readwriter @@ -615,11 +616,11 @@ def _test(): globs['hiveContext'] = HiveContext(sc) globs['df'] = globs['sqlContext'].read.parquet('python/test_support/sql/parquet_partitioned') - (failure_count, test_count) = doctest.testmod( - pyspark.sql.readwriter, globs=globs, - optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | + doctest.REPORT_NDIFF) globs['sc'].stop() - if failure_count: + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 734c1533a24bc..0701323b1c6a5 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1477,15 +1477,17 @@ def convert(self, obj, gateway_client): def _test(): import doctest + from pyspark.doctesthelper import run_doctests from pyspark.context import SparkContext from pyspark.sql import SQLContext globs = globals() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc globs['sqlContext'] = SQLContext(sc) - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + result = run_doctests(__file__, globs=globs, + optionflags=doctest.ELLIPSIS) globs['sc'].stop() - if failure_count: + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index 57bbe340bbd4d..5e103a57e019c 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -15,7 +15,6 @@ # limitations under the License. 
# -import sys from pyspark import since, SparkContext from pyspark.sql.column import _to_seq, _to_java_column @@ -145,10 +144,11 @@ def rangeBetween(self, start, end): def _test(): - import doctest - SparkContext('local[4]', 'PythonTest') - (failure_count, test_count) = doctest.testmod() - if failure_count: + from pyspark.doctesthelper import run_doctests + sc = SparkContext('local[4]', 'PythonTest') + result = run_doctests(__file__) + sc.stop() + if not result.wasSuccessful(): exit(-1) diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py index 03ea0b6d33c9d..770a96cc85e0f 100644 --- a/python/pyspark/statcounter.py +++ b/python/pyspark/statcounter.py @@ -156,3 +156,10 @@ def asDict(self, sample=False): def __repr__(self): return ("(count: %s, mean: %s, stdev: %s, max: %s, min: %s)" % (self.count(), self.mean(), self.stdev(), self.max(), self.min())) + + +if __name__ == "__main__": + from pyspark.doctesthelper import run_doctests + result = run_doctests(__file__) + if not result.wasSuccessful(): + exit(-1) diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index abbbf6eb9394f..18ead72ee7208 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -144,7 +144,7 @@ def rddToFileName(prefix, suffix, timestamp): if __name__ == "__main__": - import doctest - (failure_count, test_count) = doctest.testmod() - if failure_count: + from pyspark.doctesthelper import run_doctests + result = run_doctests(__file__) + if not result.wasSuccessful(): exit(-1)
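

Illustrative usage note (not part of the patch): the hunks above repeatedly apply the same pattern, replacing doctest.testmod() with the new pyspark.doctesthelper.run_doctests helper. The sketch below recaps that pattern in one place for a hypothetical module whose doctests need a SparkContext; names and options mirror what the diff itself introduces.

import doctest

from pyspark.context import SparkContext
from pyspark.doctesthelper import run_doctests


def _test():
    # Doctests that reference `sc` expect it in the globals handed to the suite.
    globs = globals().copy()
    globs['sc'] = SparkContext('local[4]', 'PythonTest')
    # run_doctests builds doctest.DocFileSuite(file_name, module_relative=False)
    # and runs it through xmlrunner.XMLTestRunner when the optional `xmlrunner`
    # package is installed (writing reports to target/test-reports), falling back
    # to unittest.TextTestRunner otherwise. It therefore returns a unittest
    # TestResult, so success is checked with wasSuccessful() rather than the
    # (failure_count, test_count) pair returned by doctest.testmod().
    result = run_doctests(__file__, globs=globs,
                          optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if not result.wasSuccessful():
        exit(-1)


if __name__ == "__main__":
    _test()

Because DocFileSuite scans the file's text rather than importing the module, the examples run wherever they appear in the file, which is why several hunks above move doctest blocks from module-level docstrings into class docstrings without changing the examples themselves.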