
Commit 552b38f

Davies Liu (davies) authored and committed
[SPARK-12380] [PYSPARK] use SQLContext.getOrCreate in mllib
MLlib should use SQLContext.getOrCreate() instead of creating a new SQLContext.

Author: Davies Liu <[email protected]>

Closes #10338 from davies/create_context.

(cherry picked from commit 27b98e9)
Signed-off-by: Davies Liu <[email protected]>
1 parent 04e868b commit 552b38f
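
Every hunk in this commit makes the same one-line substitution, so the pattern is worth spelling out: SQLContext(sc) always constructs a fresh context, while SQLContext.getOrCreate(sc) returns the already-instantiated one, creating it only on first use. A minimal sketch of the two patterns (the final assert is purely illustrative; it relies on the 1.6-era caching behavior):

    from pyspark import SparkContext
    from pyspark.sql import SQLContext

    sc = SparkContext.getOrCreate()  # reuse the active SparkContext, or start one

    # Old pattern used by the MLlib wrappers: a brand-new SQLContext per call.
    fresh_ctx = SQLContext(sc)

    # New pattern: the cached singleton, created on first use and shared after.
    shared_ctx = SQLContext.getOrCreate(sc)

    # Repeated calls hand back the same instance, so registered UDFs, temp
    # tables and conf settings stay visible across MLlib-internal calls.
    assert SQLContext.getOrCreate(sc) is shared_ctx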

File tree: 3 files changed, +9 −11 lines


python/pyspark/mllib/common.py
Lines changed: 3 additions & 3 deletions
@@ -102,7 +102,7 @@ def _java2py(sc, r, encoding="bytes"):
             return RDD(jrdd, sc)
 
         if clsName == 'DataFrame':
-            return DataFrame(r, SQLContext(sc))
+            return DataFrame(r, SQLContext.getOrCreate(sc))
 
         if clsName in _picklable_classes:
             r = sc._jvm.SerDe.dumps(r)
@@ -125,7 +125,7 @@ def callJavaFunc(sc, func, *args):
 
 def callMLlibFunc(name, *args):
     """ Call API in PythonMLLibAPI """
-    sc = SparkContext._active_spark_context
+    sc = SparkContext.getOrCreate()
     api = getattr(sc._jvm.PythonMLLibAPI(), name)
     return callJavaFunc(sc, api, *args)
 
@@ -135,7 +135,7 @@ class JavaModelWrapper(object):
     Wrapper for the model in JVM
     """
    def __init__(self, java_model):
-        self._sc = SparkContext._active_spark_context
+        self._sc = SparkContext.getOrCreate()
         self._java_model = java_model
 
     def __del__(self):
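
In common.py the private handle SparkContext._active_spark_context, which is None until a context has been started, is replaced by the public SparkContext.getOrCreate(), which starts a context with default settings when none exists. A sketch of the behavioral difference (the error message is illustrative, not from the source):

    from pyspark import SparkContext

    # Old pattern: read the private module-level handle; the caller had to
    # make sure a context was already running, otherwise sc came back None.
    sc = SparkContext._active_spark_context
    if sc is None:
        raise RuntimeError("SparkContext should be initialized first")

    # New pattern: the public API returns the active context, or lazily
    # creates one with default configuration if nothing is running yet.
    sc = SparkContext.getOrCreate()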

python/pyspark/mllib/evaluation.py
Lines changed: 5 additions & 5 deletions
@@ -44,7 +44,7 @@ class BinaryClassificationMetrics(JavaModelWrapper):
 
     def __init__(self, scoreAndLabels):
         sc = scoreAndLabels.ctx
-        sql_ctx = SQLContext(sc)
+        sql_ctx = SQLContext.getOrCreate(sc)
         df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([
             StructField("score", DoubleType(), nullable=False),
             StructField("label", DoubleType(), nullable=False)]))
@@ -103,7 +103,7 @@ class RegressionMetrics(JavaModelWrapper):
 
     def __init__(self, predictionAndObservations):
         sc = predictionAndObservations.ctx
-        sql_ctx = SQLContext(sc)
+        sql_ctx = SQLContext.getOrCreate(sc)
         df = sql_ctx.createDataFrame(predictionAndObservations, schema=StructType([
             StructField("prediction", DoubleType(), nullable=False),
             StructField("observation", DoubleType(), nullable=False)]))
@@ -197,7 +197,7 @@ class MulticlassMetrics(JavaModelWrapper):
 
     def __init__(self, predictionAndLabels):
         sc = predictionAndLabels.ctx
-        sql_ctx = SQLContext(sc)
+        sql_ctx = SQLContext.getOrCreate(sc)
         df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([
             StructField("prediction", DoubleType(), nullable=False),
             StructField("label", DoubleType(), nullable=False)]))
@@ -338,7 +338,7 @@ class RankingMetrics(JavaModelWrapper):
 
     def __init__(self, predictionAndLabels):
         sc = predictionAndLabels.ctx
-        sql_ctx = SQLContext(sc)
+        sql_ctx = SQLContext.getOrCreate(sc)
         df = sql_ctx.createDataFrame(predictionAndLabels,
                                      schema=sql_ctx._inferSchema(predictionAndLabels))
         java_model = callMLlibFunc("newRankingMetrics", df._jdf)
@@ -424,7 +424,7 @@ class MultilabelMetrics(JavaModelWrapper):
 
     def __init__(self, predictionAndLabels):
         sc = predictionAndLabels.ctx
-        sql_ctx = SQLContext(sc)
+        sql_ctx = SQLContext.getOrCreate(sc)
         df = sql_ctx.createDataFrame(predictionAndLabels,
                                      schema=sql_ctx._inferSchema(predictionAndLabels))
         java_class = sc._jvm.org.apache.spark.mllib.evaluation.MultilabelMetrics
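
For callers nothing changes in the API; the constructors above just stop allocating a SQLContext per metrics object. A small usage sketch (the scores and labels are made-up sample data):

    from pyspark import SparkContext
    from pyspark.mllib.evaluation import BinaryClassificationMetrics

    sc = SparkContext.getOrCreate()
    score_and_labels = sc.parallelize(
        [(0.1, 0.0), (0.4, 0.0), (0.6, 1.0), (0.8, 1.0)])

    # The constructor now fetches the shared SQLContext internally, so
    # building several metrics objects reuses one context instead of
    # creating a new one for every evaluation.
    metrics = BinaryClassificationMetrics(score_and_labels)
    print(metrics.areaUnderROC)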

python/pyspark/mllib/feature.py
Lines changed: 1 addition & 3 deletions
@@ -30,7 +30,7 @@
 
 from py4j.protocol import Py4JJavaError
 
-from pyspark import SparkContext, since
+from pyspark import since
 from pyspark.rdd import RDD, ignore_unicode_prefix
 from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
 from pyspark.mllib.linalg import (
@@ -100,8 +100,6 @@ def transform(self, vector):
         :return: normalized vector. If the norm of the input is zero, it
                  will return the input vector.
         """
-        sc = SparkContext._active_spark_context
-        assert sc is not None, "SparkContext should be initialized first"
         if isinstance(vector, RDD):
             vector = vector.map(_convert_to_vector)
         else:
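
With the explicit guard gone, Normalizer.transform simply delegates to the JVM; the context lookup now happens inside callMLlibFunc via SparkContext.getOrCreate(). A small sketch of the unchanged caller-facing behavior:

    from pyspark import SparkContext
    from pyspark.mllib.feature import Normalizer
    from pyspark.mllib.linalg import Vectors

    SparkContext.getOrCreate()  # optional: callMLlibFunc would create it anyway

    normalizer = Normalizer(p=2.0)      # unit L2 norm
    v = Vectors.dense([3.0, 4.0])

    # No manual "SparkContext should be initialized first" assert is needed:
    # the JVM call path (callMLlibFunc) obtains the context itself.
    print(normalizer.transform(v))      # [0.6, 0.8]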
