
Commit 10ba6e1

[SPARK-1087] Move python traceback utilities into new traceback_utils.py file.
1 parent 2aea0da commit 10ba6e1

File tree

3 files changed: +86 −57


python/pyspark/context.py

Lines changed: 3 additions & 2 deletions
@@ -33,6 +33,7 @@
 from pyspark.storagelevel import StorageLevel
 from pyspark import rdd
 from pyspark.rdd import RDD
+from pyspark.traceback_utils import extract_concise_traceback

 from py4j.java_collections import ListConverter

@@ -99,8 +100,8 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
         ...
         ValueError:...
         """
-        if rdd._extract_concise_traceback() is not None:
-            self._callsite = rdd._extract_concise_traceback()
+        if extract_concise_traceback() is not None:
+            self._callsite = extract_concise_traceback()
         else:
             tempNamedTuple = namedtuple("Callsite", "function file linenum")
             self._callsite = tempNamedTuple(function=None, file=None, linenum=None)
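
The behavior preserved here is the fallback: when extract_concise_traceback() returns None, SparkContext records a Callsite namedtuple whose fields are all None. A minimal sketch of that pattern (the stub extractor is hypothetical, standing in for the real import, and is hardwired to return None so the fallback branch runs):

from collections import namedtuple

def extract_concise_traceback():
    # Hypothetical stub standing in for
    # pyspark.traceback_utils.extract_concise_traceback;
    # returns None to exercise the fallback branch.
    return None

if extract_concise_traceback() is not None:
    callsite = extract_concise_traceback()
else:
    tempNamedTuple = namedtuple("Callsite", "function file linenum")
    callsite = tempNamedTuple(function=None, file=None, linenum=None)

print(callsite)  # Callsite(function=None, file=None, linenum=None)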

python/pyspark/rdd.py

Lines changed: 3 additions & 55 deletions
@@ -18,13 +18,11 @@
 from base64 import standard_b64encode as b64enc
 import copy
 from collections import defaultdict
-from collections import namedtuple
 from itertools import chain, ifilter, imap
 import operator
 import os
 import sys
 import shlex
-import traceback
 from subprocess import Popen, PIPE
 from tempfile import NamedTemporaryFile
 from threading import Thread
@@ -45,6 +43,7 @@
 from pyspark.resultiterable import ResultIterable
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
     get_used_memory, ExternalSorter
+from pyspark.traceback_utils import JavaStackTrace

 from py4j.java_collections import ListConverter, MapConverter

@@ -81,57 +80,6 @@ def portable_hash(x):
     return hash(x)


-def _extract_concise_traceback():
-    """
-    This function returns the traceback info for a callsite, returns a dict
-    with function name, file name and line number
-    """
-    tb = traceback.extract_stack()
-    callsite = namedtuple("Callsite", "function file linenum")
-    if len(tb) == 0:
-        return None
-    file, line, module, what = tb[len(tb) - 1]
-    sparkpath = os.path.dirname(file)
-    first_spark_frame = len(tb) - 1
-    for i in range(0, len(tb)):
-        file, line, fun, what = tb[i]
-        if file.startswith(sparkpath):
-            first_spark_frame = i
-            break
-    if first_spark_frame == 0:
-        file, line, fun, what = tb[0]
-        return callsite(function=fun, file=file, linenum=line)
-    sfile, sline, sfun, swhat = tb[first_spark_frame]
-    ufile, uline, ufun, uwhat = tb[first_spark_frame - 1]
-    return callsite(function=sfun, file=ufile, linenum=uline)
-
-_spark_stack_depth = 0
-
-
-class _JavaStackTrace(object):
-
-    def __init__(self, sc):
-        tb = _extract_concise_traceback()
-        if tb is not None:
-            self._traceback = "%s at %s:%s" % (
-                tb.function, tb.file, tb.linenum)
-        else:
-            self._traceback = "Error! Could not extract traceback info"
-        self._context = sc
-
-    def __enter__(self):
-        global _spark_stack_depth
-        if _spark_stack_depth == 0:
-            self._context._jsc.setCallSite(self._traceback)
-        _spark_stack_depth += 1
-
-    def __exit__(self, type, value, tb):
-        global _spark_stack_depth
-        _spark_stack_depth -= 1
-        if _spark_stack_depth == 0:
-            self._context._jsc.setCallSite(None)
-
-
 class BoundedFloat(float):
     """
     Bounded value is generated by approximate job, with confidence and low
@@ -704,7 +652,7 @@ def collect(self):
         """
         Return a list that contains all of the elements in this RDD.
         """
-        with _JavaStackTrace(self.context) as st:
+        with JavaStackTrace(self.context) as st:
             bytesInJava = self._jrdd.collect().iterator()
         return list(self._collect_iterator_through_file(bytesInJava))

@@ -1515,7 +1463,7 @@ def add_shuffle_key(split, iterator):

        keyed = self.mapPartitionsWithIndex(add_shuffle_key)
        keyed._bypass_serializer = True
-        with _JavaStackTrace(self.context) as st:
+        with JavaStackTrace(self.context) as st:
            pairRDD = self.ctx._jvm.PairwiseRDD(
                keyed._jrdd.rdd()).asJavaPairRDD()
            partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
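
For readers following the move: the function deleted above (and re-added in traceback_utils.py below) scans traceback.extract_stack() output, oldest frame first, for the first frame whose file sits under the pyspark install path, then reports that Spark function name together with the file and line of the user frame just above it. A standalone sketch of that selection logic on a synthetic stack (pick_callsite and the paths are illustrative, not part of the commit):

import os
from collections import namedtuple

Callsite = namedtuple("Callsite", "function file linenum")

def pick_callsite(tb, sparkpath):
    # tb: list of (file, line, function, text) tuples, oldest frame
    # first, in the format produced by traceback.extract_stack().
    if len(tb) == 0:
        return None
    first_spark_frame = len(tb) - 1
    for i, (file, line, fun, what) in enumerate(tb):
        if file.startswith(sparkpath):
            first_spark_frame = i
            break
    if first_spark_frame == 0:
        file, line, fun, what = tb[0]
        return Callsite(function=fun, file=file, linenum=line)
    sfile, sline, sfun, swhat = tb[first_spark_frame]
    ufile, uline, ufun, uwhat = tb[first_spark_frame - 1]
    # Report the Spark API entry point, but the user's file and line.
    return Callsite(function=sfun, file=ufile, linenum=uline)

stack = [
    ("/app/job.py", 10, "<module>", "rdd.collect()"),
    ("/opt/pyspark/rdd.py", 652, "collect", "with JavaStackTrace(...)"),
]
print(pick_callsite(stack, "/opt/pyspark"))
# Callsite(function='collect', file='/app/job.py', linenum=10)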

python/pyspark/traceback_utils.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from collections import namedtuple
+import os
+import traceback
+
+
+__all__ = ["extract_concise_traceback", "JavaStackTrace"]
+
+
+def extract_concise_traceback():
+    """
+    This function returns the traceback info for a callsite, as a namedtuple
+    with function name, file name and line number, or None if no stack
+    frames are available.
+    """
+    tb = traceback.extract_stack()
+    callsite = namedtuple("Callsite", "function file linenum")
+    if len(tb) == 0:
+        return None
+    file, line, module, what = tb[len(tb) - 1]
+    sparkpath = os.path.dirname(file)
+    first_spark_frame = len(tb) - 1
+    for i in range(0, len(tb)):
+        file, line, fun, what = tb[i]
+        if file.startswith(sparkpath):
+            first_spark_frame = i
+            break
+    if first_spark_frame == 0:
+        file, line, fun, what = tb[0]
+        return callsite(function=fun, file=file, linenum=line)
+    sfile, sline, sfun, swhat = tb[first_spark_frame]
+    ufile, uline, ufun, uwhat = tb[first_spark_frame - 1]
+    return callsite(function=sfun, file=ufile, linenum=uline)
+
+
+class JavaStackTrace(object):
+    """
+    Helper for setting the spark context call site.
+
+    Example usage:
+    from pyspark.traceback_utils import JavaStackTrace
+    with JavaStackTrace(<relevant SparkContext>) as st:
+        <a Spark call>
+    """
+
+    _spark_stack_depth = 0
+
+    def __init__(self, sc):
+        tb = extract_concise_traceback()
+        if tb is not None:
+            self._traceback = "%s at %s:%s" % (
+                tb.function, tb.file, tb.linenum)
+        else:
+            self._traceback = "Error! Could not extract traceback info"
+        self._context = sc
+
+    def __enter__(self):
+        if JavaStackTrace._spark_stack_depth == 0:
+            self._context._jsc.setCallSite(self._traceback)
+        JavaStackTrace._spark_stack_depth += 1
+
+    def __exit__(self, type, value, tb):
+        JavaStackTrace._spark_stack_depth -= 1
+        if JavaStackTrace._spark_stack_depth == 0:
+            self._context._jsc.setCallSite(None)
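
The class-level _spark_stack_depth counter makes nested with-blocks cheap: only the outermost __enter__ calls setCallSite on the JVM, and only the outermost __exit__ clears it. A quick way to see that behavior with a stub context object (_FakeJSC and _FakeContext are made-up names for this demo, not part of the commit):

from pyspark.traceback_utils import JavaStackTrace

class _FakeJSC(object):
    # Stub standing in for the py4j JavaSparkContext handle;
    # just logs each call instead of talking to a JVM.
    def setCallSite(self, site):
        print("setCallSite(%r)" % site)

class _FakeContext(object):
    _jsc = _FakeJSC()

sc = _FakeContext()
with JavaStackTrace(sc):      # outermost block: sets the call site
    with JavaStackTrace(sc):  # nested block: depth > 0, no extra JVM call
        pass
# Prints one setCallSite('<fun> at <file>:<line>') on entry
# and one setCallSite(None) on exit.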
