diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 6dc1721f56adf..a577194a48006 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -168,6 +168,21 @@ private[spark] object PythonRDD extends Logging {
     serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
   }
 
+  /**
+   * A helper function to collect an RDD as an iterator, then serve it via socket.
+   * This method is similar to `PythonRDD.collectAndServe`, but the user can specify the
+   * job group id, job description, and interruptOnCancel option.
+   */
+  def collectAndServeWithJobGroup[T](
+      rdd: RDD[T],
+      groupId: String,
+      description: String,
+      interruptOnCancel: Boolean): Array[Any] = {
+    val sc = rdd.sparkContext
+    sc.setJobGroup(groupId, description, interruptOnCancel)
+    serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
+  }
+
   /**
    * A helper function to create a local RDD iterator and serve it via socket. Partitions are
    * collected as separate jobs, by order of index. Partition data is first requested by a
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index fbf645d10ee86..d0ac000ba3208 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -877,6 +877,19 @@ def collect(self):
             sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
         return list(_load_from_socket(sock_info, self._jrdd_deserializer))
 
+    def collectWithJobGroup(self, groupId, description, interruptOnCancel=False):
+        """
+        .. note:: Experimental
+
+        When collecting this RDD, use this method to specify a job group for the collect job.
+
+        .. versionadded:: 3.0.0
+        """
+        with SCCallSiteSync(self.context) as css:
+            sock_info = self.ctx._jvm.PythonRDD.collectAndServeWithJobGroup(
+                self._jrdd.rdd(), groupId, description, interruptOnCancel)
+            return list(_load_from_socket(sock_info, self._jrdd_deserializer))
+
     def reduce(self, f):
         """
         Reduces the elements of this RDD using the specified commutative and
diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py
index 31c5a7510a165..62ad4221d7078 100644
--- a/python/pyspark/tests/test_rdd.py
+++ b/python/pyspark/tests/test_rdd.py
@@ -814,6 +814,68 @@ def assert_request_contents(exec_reqs, task_reqs):
         rddWithoutRp = self.sc.parallelize(range(10))
         self.assertEqual(rddWithoutRp.getResourceProfile(), None)
 
+    def test_multiple_group_jobs(self):
+        import threading
+        group_a = "job_ids_to_cancel"
+        group_b = "job_ids_to_run"
+
+        threads = []
+        thread_ids = range(4)
+        thread_ids_to_cancel = [i for i in thread_ids if i % 2 == 0]
+        thread_ids_to_run = [i for i in thread_ids if i % 2 != 0]
+
+        # A list that records whether each job was cancelled.
+        # The index in the list is the index of the thread that the job runs in.
+        is_job_cancelled = [False for _ in thread_ids]
+
+        def run_job(job_group, index):
+            """
+            Executes a job with the group ``job_group``. Each job sleeps for 15 seconds
+            and then exits.
+            """
+            try:
+                self.sc.parallelize([15]).map(lambda x: time.sleep(x)) \
+                    .collectWithJobGroup(job_group, "test rdd collect with setting job group")
+                is_job_cancelled[index] = False
+            except Exception:
+                # Assume that an exception means the job was cancelled.
+                is_job_cancelled[index] = True
+
+        # Test that the job succeeds when it is not cancelled.
+        run_job(group_a, 0)
+        self.assertFalse(is_job_cancelled[0])
+
+        # Run jobs in the two groups concurrently, in background threads.
+        for i in thread_ids_to_cancel:
+            t = threading.Thread(target=run_job, args=(group_a, i))
+            t.start()
+            threads.append(t)
+
+        for i in thread_ids_to_run:
+            t = threading.Thread(target=run_job, args=(group_b, i))
+            t.start()
+            threads.append(t)
+
+        # Wait to make sure all jobs have been submitted and are running.
+        time.sleep(3)
+        # Then cancel one of the job groups.
+        self.sc.cancelJobGroup(group_a)
+
+        # Wait until all threads running jobs have finished.
+        for t in threads:
+            t.join()
+
+        for i in thread_ids_to_cancel:
+            self.assertTrue(
+                is_job_cancelled[i],
+                "Thread {i}: Job in group A was not cancelled.".format(i=i))
+
+        for i in thread_ids_to_run:
+            self.assertFalse(
+                is_job_cancelled[i],
+                "Thread {i}: Job in group B did not succeed.".format(i=i))
+
+
 if __name__ == "__main__":
     import unittest
     from pyspark.tests.test_rdd import *
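For context, a minimal usage sketch of the new API added by this patch (requires Spark 3.0.0+). It assumes a local SparkContext; the group name, sleep durations, and app name are illustrative only and not part of the patch:

    import threading
    import time

    from pyspark import SparkContext

    sc = SparkContext("local[2]", "collect-with-job-group-example")

    def slow_collect():
        # Run a slow collect in its own job group so it can be cancelled independently.
        try:
            sc.parallelize([30]).map(lambda x: time.sleep(x)) \
                .collectWithJobGroup("example_group", "slow collect", interruptOnCancel=True)
        except Exception:
            print("collect in 'example_group' was cancelled")

    t = threading.Thread(target=slow_collect)
    t.start()

    time.sleep(3)                        # give the job time to start
    sc.cancelJobGroup("example_group")   # cancels only jobs in "example_group"
    t.join()
    sc.stop()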