|
18 | 18 | from base64 import standard_b64encode as b64enc |
19 | 19 | import copy |
20 | 20 | from collections import defaultdict |
21 | | -from collections import namedtuple |
22 | 21 | from itertools import chain, ifilter, imap |
23 | 22 | import operator |
24 | 23 | import os |
25 | 24 | import sys |
26 | 25 | import shlex |
27 | | -import traceback |
28 | 26 | from subprocess import Popen, PIPE |
29 | 27 | from tempfile import NamedTemporaryFile |
30 | 28 | from threading import Thread |
|
45 | 43 | from pyspark.resultiterable import ResultIterable |
46 | 44 | from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \ |
47 | 45 | get_used_memory, ExternalSorter |
| 46 | +from pyspark.traceback_utils import JavaStackTrace |
48 | 47 |
|
49 | 48 | from py4j.java_collections import ListConverter, MapConverter |
50 | 49 |
|
@@ -81,57 +80,6 @@ def portable_hash(x): |
81 | 80 | return hash(x) |
82 | 81 |
|
83 | 82 |
|
84 | | -def _extract_concise_traceback(): |
85 | | - """ |
86 | | - This function returns the traceback info for a callsite, returns a dict |
87 | | - with function name, file name and line number |
88 | | - """ |
89 | | - tb = traceback.extract_stack() |
90 | | - callsite = namedtuple("Callsite", "function file linenum") |
91 | | - if len(tb) == 0: |
92 | | - return None |
93 | | - file, line, module, what = tb[len(tb) - 1] |
94 | | - sparkpath = os.path.dirname(file) |
95 | | - first_spark_frame = len(tb) - 1 |
96 | | - for i in range(0, len(tb)): |
97 | | - file, line, fun, what = tb[i] |
98 | | - if file.startswith(sparkpath): |
99 | | - first_spark_frame = i |
100 | | - break |
101 | | - if first_spark_frame == 0: |
102 | | - file, line, fun, what = tb[0] |
103 | | - return callsite(function=fun, file=file, linenum=line) |
104 | | - sfile, sline, sfun, swhat = tb[first_spark_frame] |
105 | | - ufile, uline, ufun, uwhat = tb[first_spark_frame - 1] |
106 | | - return callsite(function=sfun, file=ufile, linenum=uline) |
107 | | - |
108 | | -_spark_stack_depth = 0 |
109 | | - |
110 | | - |
111 | | -class _JavaStackTrace(object): |
112 | | - |
113 | | - def __init__(self, sc): |
114 | | - tb = _extract_concise_traceback() |
115 | | - if tb is not None: |
116 | | - self._traceback = "%s at %s:%s" % ( |
117 | | - tb.function, tb.file, tb.linenum) |
118 | | - else: |
119 | | - self._traceback = "Error! Could not extract traceback info" |
120 | | - self._context = sc |
121 | | - |
122 | | - def __enter__(self): |
123 | | - global _spark_stack_depth |
124 | | - if _spark_stack_depth == 0: |
125 | | - self._context._jsc.setCallSite(self._traceback) |
126 | | - _spark_stack_depth += 1 |
127 | | - |
128 | | - def __exit__(self, type, value, tb): |
129 | | - global _spark_stack_depth |
130 | | - _spark_stack_depth -= 1 |
131 | | - if _spark_stack_depth == 0: |
132 | | - self._context._jsc.setCallSite(None) |
133 | | - |
134 | | - |
135 | 83 | class BoundedFloat(float): |
136 | 84 | """ |
137 | 85 | Bounded value is generated by approximate job, with confidence and low |
@@ -704,7 +652,7 @@ def collect(self): |
704 | 652 | """ |
705 | 653 | Return a list that contains all of the elements in this RDD. |
706 | 654 | """ |
707 | | - with _JavaStackTrace(self.context) as st: |
| 655 | + with JavaStackTrace(self.context) as st: |
708 | 656 | bytesInJava = self._jrdd.collect().iterator() |
709 | 657 | return list(self._collect_iterator_through_file(bytesInJava)) |
710 | 658 |
|
@@ -1515,7 +1463,7 @@ def add_shuffle_key(split, iterator): |
1515 | 1463 |
|
1516 | 1464 | keyed = self.mapPartitionsWithIndex(add_shuffle_key) |
1517 | 1465 | keyed._bypass_serializer = True |
1518 | | - with _JavaStackTrace(self.context) as st: |
| 1466 | + with JavaStackTrace(self.context) as st: |
1519 | 1467 | pairRDD = self.ctx._jvm.PairwiseRDD( |
1520 | 1468 | keyed._jrdd.rdd()).asJavaPairRDD() |
1521 | 1469 | partitioner = self.ctx._jvm.PythonPartitioner(numPartitions, |
|
0 commit comments