
Commit 57d70d2

BryanCutler authored and holdenk committed
[SPARK-17161][PYSPARK][ML] Add PySpark-ML JavaWrapper convenience function to create Py4J JavaArrays
## What changes were proposed in this pull request?

Adds a convenience function to the Python `JavaWrapper` so that it is easy to create a Py4J JavaArray compatible with existing class constructors that take a Scala `Array` as input, so that a separate Java/Python-friendly constructor is no longer necessary. The function takes a Java class as input, which Py4J uses to create a Java array of that class. As an example, `OneVsRest` has been updated to use this and its alternate constructor is removed.

## How was this patch tested?

Added unit tests for the new convenience function and updated the `OneVsRest` doctests, which use it to persist the model.

Author: Bryan Cutler <[email protected]>

Closes #14725 from BryanCutler/pyspark-new_java_array-CountVectorizer-SPARK-17161.
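For reference, a minimal sketch of how the new helper is meant to be called from Python. It mirrors the unit tests added in this commit; the variable `sc` is assumed here to be an active `SparkContext` (for example, inside a PySpark shell):

```python
from pyspark.ml.wrapper import JavaWrapper

# `sc` is assumed to be an active SparkContext, e.g. in a PySpark shell.
# Pick the Java element type through the Py4J gateway, then build the array.
java_class = sc._gateway.jvm.java.lang.String
java_array = JavaWrapper._new_java_array(["a", "b", "c"], java_class)

# The resulting Py4J JavaArray can be passed to JVM-side constructors or
# methods that expect a Scala Array[String].
```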
1 parent ce112ce commit 57d70d2

5 files changed: +81 lines, -9 lines


mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala

Lines changed: 0 additions & 5 deletions
@@ -135,11 +135,6 @@ final class OneVsRestModel private[ml] (
     @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
   extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {
 
-  /** A Python-friendly auxiliary constructor. */
-  private[ml] def this(uid: String, models: JList[_ <: ClassificationModel[_, _]]) = {
-    this(uid, Metadata.empty, models.asScala.toArray)
-  }
-
   /** @group setParam */
   @Since("2.1.0")
   def setFeaturesCol(value: String): this.type = set(featuresCol, value)

project/MimaExcludes.scala

Lines changed: 4 additions & 1 deletion
@@ -54,7 +54,10 @@ object MimaExcludes {
       // [SPARK-19069] [CORE] Expose task 'status' and 'duration' in spark history server REST API.
       ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.this"),
       ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.<init>$default$10"),
-      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.<init>$default$11")
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.<init>$default$11"),
+
+      // [SPARK-17161] Removing Python-friendly constructors not needed
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.this")
     )
 
     // Exclude rules for 2.1.x

python/pyspark/ml/classification.py

Lines changed: 10 additions & 1 deletion
@@ -1517,6 +1517,11 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):
     >>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4))]).toDF()
     >>> model.transform(test2).head().prediction
     2.0
+    >>> model_path = temp_path + "/ovr_model"
+    >>> model.save(model_path)
+    >>> model2 = OneVsRestModel.load(model_path)
+    >>> model2.transform(test0).head().prediction
+    1.0
 
     .. versionadded:: 2.0.0
     """

@@ -1759,9 +1764,13 @@ def _to_java(self):
 
         :return: Java object equivalent to this instance.
         """
+        sc = SparkContext._active_spark_context
         java_models = [model._to_java() for model in self.models]
+        java_models_array = JavaWrapper._new_java_array(
+            java_models, sc._gateway.jvm.org.apache.spark.ml.classification.ClassificationModel)
+        metadata = JavaParams._new_java_obj("org.apache.spark.sql.types.Metadata")
         _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRestModel",
-                                             self.uid, java_models)
+                                             self.uid, metadata.empty(), java_models_array)
         _java_obj.set("classifier", self.getClassifier()._to_java())
         _java_obj.set("featuresCol", self.getFeaturesCol())
         _java_obj.set("labelCol", self.getLabelCol())

python/pyspark/ml/tests.py

Lines changed: 38 additions & 2 deletions
@@ -60,8 +60,8 @@
 from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor, \
     GeneralizedLinearRegression
 from pyspark.ml.tuning import *
-from pyspark.ml.wrapper import JavaParams
-from pyspark.ml.common import _java2py
+from pyspark.ml.wrapper import JavaParams, JavaWrapper
+from pyspark.ml.common import _java2py, _py2java
 from pyspark.serializers import PickleSerializer
 from pyspark.sql import DataFrame, Row, SparkSession
 from pyspark.sql.functions import rand

@@ -1620,6 +1620,42 @@ def test_infer_schema(self):
             raise ValueError("Expected a matrix but got type %r" % type(m))
 
 
+class WrapperTests(MLlibTestCase):
+
+    def test_new_java_array(self):
+        # test array of strings
+        str_list = ["a", "b", "c"]
+        java_class = self.sc._gateway.jvm.java.lang.String
+        java_array = JavaWrapper._new_java_array(str_list, java_class)
+        self.assertEqual(_java2py(self.sc, java_array), str_list)
+        # test array of integers
+        int_list = [1, 2, 3]
+        java_class = self.sc._gateway.jvm.java.lang.Integer
+        java_array = JavaWrapper._new_java_array(int_list, java_class)
+        self.assertEqual(_java2py(self.sc, java_array), int_list)
+        # test array of floats
+        float_list = [0.1, 0.2, 0.3]
+        java_class = self.sc._gateway.jvm.java.lang.Double
+        java_array = JavaWrapper._new_java_array(float_list, java_class)
+        self.assertEqual(_java2py(self.sc, java_array), float_list)
+        # test array of bools
+        bool_list = [False, True, True]
+        java_class = self.sc._gateway.jvm.java.lang.Boolean
+        java_array = JavaWrapper._new_java_array(bool_list, java_class)
+        self.assertEqual(_java2py(self.sc, java_array), bool_list)
+        # test array of Java DenseVectors
+        v1 = DenseVector([0.0, 1.0])
+        v2 = DenseVector([1.0, 0.0])
+        vec_java_list = [_py2java(self.sc, v1), _py2java(self.sc, v2)]
+        java_class = self.sc._gateway.jvm.org.apache.spark.ml.linalg.DenseVector
+        java_array = JavaWrapper._new_java_array(vec_java_list, java_class)
+        self.assertEqual(_java2py(self.sc, java_array), [v1, v2])
+        # test empty array
+        java_class = self.sc._gateway.jvm.java.lang.Integer
+        java_array = JavaWrapper._new_java_array([], java_class)
+        self.assertEqual(_java2py(self.sc, java_array), [])
+
+
 if __name__ == "__main__":
     from pyspark.ml.tests import *
     if xmlrunner:

python/pyspark/ml/wrapper.py

Lines changed: 29 additions & 0 deletions
@@ -16,6 +16,9 @@
 #
 
 from abc import ABCMeta, abstractmethod
+import sys
+if sys.version >= '3':
+    xrange = range
 
 from pyspark import SparkContext
 from pyspark.sql import DataFrame

@@ -59,6 +62,32 @@ def _new_java_obj(java_class, *args):
         java_args = [_py2java(sc, arg) for arg in args]
         return java_obj(*java_args)
 
+    @staticmethod
+    def _new_java_array(pylist, java_class):
+        """
+        Create a Java array of given java_class type. Useful for
+        calling a method with a Scala Array from Python with Py4J.
+
+        :param pylist:
+          Python list to convert to a Java Array.
+        :param java_class:
+          Java class to specify the type of Array. Should be in the
+          form of sc._gateway.jvm.* (sc is a valid Spark Context).
+        :return:
+          Java Array of converted pylist.
+
+        Example primitive Java classes:
+          - basestring -> sc._gateway.jvm.java.lang.String
+          - int -> sc._gateway.jvm.java.lang.Integer
+          - float -> sc._gateway.jvm.java.lang.Double
+          - bool -> sc._gateway.jvm.java.lang.Boolean
+        """
+        sc = SparkContext._active_spark_context
+        java_array = sc._gateway.new_array(java_class, len(pylist))
+        for i in xrange(len(pylist)):
+            java_array[i] = pylist[i]
+        return java_array
+
 
 @inherit_doc
 class JavaParams(JavaWrapper, Params):
