Skip to content

Commit 3d858a2

Browse files
jkbradleytechaddict
authored andcommitted
Fixing copy bug (#1)
* moved copy from JavaModel to JavaParams. mv del from JavaModel to JavaWrapper * added test which fails before this fix
1 parent f25b099 commit 3d858a2

File tree

2 files changed

+39
-22
lines changed

2 files changed

+39
-22
lines changed

python/pyspark/ml/tests.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,24 @@ def test_word2vec_param(self):
390390
self.assertEqual(model.getWindowSize(), 6)
391391

392392

393+
class EvaluatorTests(SparkSessionTestCase):
394+
395+
def test_java_params(self):
396+
"""
397+
This tests a bug fixed by SPARK-18274 which causes multiple copies
398+
of a Params instance in Python to be linked to the same Java instance.
399+
"""
400+
evaluator = RegressionEvaluator(metricName="r2")
401+
df = self.spark.createDataFrame([Row(label=1.0, prediction=1.1)])
402+
evaluator.evaluate(df)
403+
self.assertEqual(evaluator._java_obj.getMetricName(), "r2")
404+
evaluatorCopy = evaluator.copy({evaluator.metricName: "mae"})
405+
evaluator.evaluate(df)
406+
evaluatorCopy.evaluate(df)
407+
self.assertEqual(evaluator._java_obj.getMetricName(), "r2")
408+
self.assertEqual(evaluatorCopy._java_obj.getMetricName(), "mae")
409+
410+
393411
class FeatureTests(SparkSessionTestCase):
394412

395413
def test_binarizer(self):

python/pyspark/ml/wrapper.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ def __init__(self, java_obj=None):
3333
super(JavaWrapper, self).__init__()
3434
self._java_obj = java_obj
3535

36+
def __del__(self):
37+
SparkContext._active_spark_context._gateway.detach(self._java_obj)
38+
3639
@classmethod
3740
def _create_from_java_class(cls, java_class, *args):
3841
"""
@@ -180,6 +183,24 @@ def __get_class(clazz):
180183
% stage_name)
181184
return py_stage
182185

186+
def copy(self, extra=None):
187+
"""
188+
Creates a copy of this instance with the same uid and some
189+
extra params. This implementation first calls Params.copy and
190+
then make a copy of the companion Java model with extra params.
191+
So both the Python wrapper and the Java model get copied.
192+
193+
:param extra: Extra parameters to copy to the new instance
194+
:return: Copy of this instance
195+
"""
196+
if extra is None:
197+
extra = dict()
198+
that = super(JavaParams, self).copy(extra)
199+
if self._java_obj is not None:
200+
that._java_obj = self._java_obj.copy(self._empty_java_param_map())
201+
that._transfer_params_to_java()
202+
return that
203+
183204

184205
@inherit_doc
185206
class JavaEstimator(JavaParams, Estimator):
@@ -256,25 +277,3 @@ def __init__(self, java_model=None):
256277
super(JavaModel, self).__init__(java_model)
257278
if java_model is not None:
258279
self._resetUid(java_model.uid())
259-
260-
def __del__(self):
261-
if SparkContext._gateway:
262-
SparkContext._gateway.detach(self._java_obj)
263-
264-
def copy(self, extra=None):
265-
"""
266-
Creates a copy of this instance with the same uid and some
267-
extra params. This implementation first calls Params.copy and
268-
then make a copy of the companion Java model with extra params.
269-
So both the Python wrapper and the Java model get copied.
270-
271-
:param extra: Extra parameters to copy to the new instance
272-
:return: Copy of this instance
273-
"""
274-
if extra is None:
275-
extra = dict()
276-
that = super(JavaModel, self).copy(extra)
277-
if self._java_obj is not None:
278-
that._java_obj = self._java_obj.copy(self._empty_java_param_map())
279-
that._transfer_params_to_java()
280-
return that

0 commit comments

Comments
 (0)