From 5b4e37cd8487f9a4c196bd7c85cef6b5def07b45 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Tue, 9 Sep 2014 11:58:26 -0400 Subject: [PATCH] [SPARK-3458] enable python "with" statements for SparkContext allow for best practice code, try: sc = SparkContext() app(sc) finally: sc.stop() to be written using a "with" statement, with SparkContext() as sc: app(sc) --- python/pyspark/context.py | 14 ++++++++++++++ python/pyspark/tests.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 5a30431568b16..84bc0a3b7ccd0 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -232,6 +232,20 @@ def _ensure_initialized(cls, instance=None, gateway=None): else: SparkContext._active_spark_context = instance + def __enter__(self): + """ + Enable 'with SparkContext(...) as sc: app(sc)' syntax. + """ + return self + + def __exit__(self, type, value, trace): + """ + Enable 'with SparkContext(...) as sc: app' syntax. + + Specifically stop the context on exit of the with block. + """ + self.stop() + @classmethod def setSystemProperty(cls, key, value): """ diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 0bd2a9e6c507d..bb84ebe72cb24 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1254,6 +1254,35 @@ def test_single_script_on_cluster(self): self.assertIn("[2, 4, 6]", out) +class ContextStopTests(unittest.TestCase): + + def test_stop(self): + sc = SparkContext() + self.assertNotEqual(SparkContext._active_spark_context, None) + sc.stop() + self.assertEqual(SparkContext._active_spark_context, None) + + def test_with(self): + with SparkContext() as sc: + self.assertNotEqual(SparkContext._active_spark_context, None) + self.assertEqual(SparkContext._active_spark_context, None) + + def test_with_exception(self): + try: + with SparkContext() as sc: + self.assertNotEqual(SparkContext._active_spark_context, None) + raise Exception() + except: + pass + self.assertEqual(SparkContext._active_spark_context, None) + + def test_with_stop(self): + with SparkContext() as sc: + self.assertNotEqual(SparkContext._active_spark_context, None) + sc.stop() + self.assertEqual(SparkContext._active_spark_context, None) + + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase):