@@ -135,11 +135,6 @@ final class OneVsRestModel private[ml] (
@Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {

- /** A Python-friendly auxiliary constructor. */
- private[ml] def this(uid: String, models: JList[_ <: ClassificationModel[_, _]]) = {
-   this(uid, Metadata.empty, models.asScala.toArray)
- }
-
/** @group setParam */
@Since("2.1.0")
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
5 changes: 4 additions & 1 deletion project/MimaExcludes.scala
@@ -54,7 +54,10 @@ object MimaExcludes {
// [SPARK-19069] [CORE] Expose task 'status' and 'duration' in spark history server REST API.
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.this"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.<init>$default$10"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.<init>$default$11")
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.<init>$default$11"),

// [SPARK-17161] Removing Python-friendly constructors not needed
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.this")
)

// Exclude rules for 2.1.x
11 changes: 10 additions & 1 deletion python/pyspark/ml/classification.py
@@ -1388,6 +1388,11 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):
>>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4))]).toDF()
>>> model.transform(test2).head().prediction
2.0
>>> model_path = temp_path + "/ovr_model"
>>> model.save(model_path)
>>> model2 = OneVsRestModel.load(model_path)
>>> model2.transform(test0).head().prediction
1.0

.. versionadded:: 2.0.0
"""
@@ -1630,9 +1635,13 @@ def _to_java(self):

:return: Java object equivalent to this instance.
"""
sc = SparkContext._active_spark_context
java_models = [model._to_java() for model in self.models]
java_models_array = JavaWrapper._new_java_array(
java_models, sc._gateway.jvm.org.apache.spark.ml.classification.ClassificationModel)
metadata = JavaParams._new_java_obj("org.apache.spark.sql.types.Metadata")
_java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRestModel",
- self.uid, java_models)
+ self.uid, metadata.empty(), java_models_array)
_java_obj.set("classifier", self.getClassifier()._to_java())
_java_obj.set("featuresCol", self.getFeaturesCol())
_java_obj.set("labelCol", self.getLabelCol())
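For context on the _to_java change above: Py4J presents a plain Python list to the JVM as a java.util.List, which does not match the Scala constructor parameter typed as Array[ClassificationModel[_, _]], so a real typed Java array has to be built instead. A minimal sketch of the underlying Py4J mechanism (not code from this PR), assuming an active SparkContext named sc:

from pyspark import SparkContext

sc = SparkContext._active_spark_context  # assumes a SparkContext has already been started
gateway = sc._gateway

# gateway.new_array builds a real typed Java array (here a java.lang.String[] of length 2),
# unlike a Python list, which Py4J would surface to the JVM as a java.util.List.
jarr = gateway.new_array(gateway.jvm.java.lang.String, 2)
jarr[0] = "a"
jarr[1] = "b"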
40 changes: 38 additions & 2 deletions python/pyspark/ml/tests.py
@@ -60,8 +60,8 @@
from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor, \
GeneralizedLinearRegression
from pyspark.ml.tuning import *
- from pyspark.ml.wrapper import JavaParams
- from pyspark.ml.common import _java2py
+ from pyspark.ml.wrapper import JavaParams, JavaWrapper
+ from pyspark.ml.common import _java2py, _py2java
from pyspark.serializers import PickleSerializer
from pyspark.sql import DataFrame, Row, SparkSession
from pyspark.sql.functions import rand
@@ -1620,6 +1620,42 @@ def test_infer_schema(self):
raise ValueError("Expected a matrix but got type %r" % type(m))


class WrapperTests(MLlibTestCase):

def test_new_java_array(self):
# test array of strings
str_list = ["a", "b", "c"]
java_class = self.sc._gateway.jvm.java.lang.String
java_array = JavaWrapper._new_java_array(str_list, java_class)
self.assertEqual(_java2py(self.sc, java_array), str_list)
# test array of integers
int_list = [1, 2, 3]
java_class = self.sc._gateway.jvm.java.lang.Integer
java_array = JavaWrapper._new_java_array(int_list, java_class)
self.assertEqual(_java2py(self.sc, java_array), int_list)
# test array of floats
float_list = [0.1, 0.2, 0.3]
java_class = self.sc._gateway.jvm.java.lang.Double
java_array = JavaWrapper._new_java_array(float_list, java_class)
self.assertEqual(_java2py(self.sc, java_array), float_list)
# test array of bools
bool_list = [False, True, True]
java_class = self.sc._gateway.jvm.java.lang.Boolean
java_array = JavaWrapper._new_java_array(bool_list, java_class)
self.assertEqual(_java2py(self.sc, java_array), bool_list)
# test array of Java DenseVectors
v1 = DenseVector([0.0, 1.0])
v2 = DenseVector([1.0, 0.0])
vec_java_list = [_py2java(self.sc, v1), _py2java(self.sc, v2)]
java_class = self.sc._gateway.jvm.org.apache.spark.ml.linalg.DenseVector
java_array = JavaWrapper._new_java_array(vec_java_list, java_class)
self.assertEqual(_java2py(self.sc, java_array), [v1, v2])
# test empty array
java_class = self.sc._gateway.jvm.java.lang.Integer
java_array = JavaWrapper._new_java_array([], java_class)
self.assertEqual(_java2py(self.sc, java_array), [])


if __name__ == "__main__":
from pyspark.ml.tests import *
if xmlrunner:
29 changes: 29 additions & 0 deletions python/pyspark/ml/wrapper.py
@@ -16,6 +16,9 @@
#

from abc import ABCMeta, abstractmethod
import sys
Review thread on the added sys/xrange imports:

Contributor: Are these imports needed anymore?

Member Author: Oh, good catch, thanks! I wasn't using the def for basestring, but I'm still iterating the array/list with xrange. I was thinking it would be possible for someone to use a large list, for example when using a vocabulary. I suppose I could use enumerate if you think that's better?

Member Author: Does this look ok now @holdenk?

Contributor: That looks fine; we use xrange similarly elsewhere.

if sys.version >= '3':
xrange = range

from pyspark import SparkContext
from pyspark.sql import DataFrame
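For illustration, a minimal sketch of the enumerate-based alternative raised in the review thread above; the PR keeps the xrange loop, so this is hypothetical:

# Hypothetical alternative to the xrange loop discussed above (not what was merged):
# enumerate yields (index, value) pairs in a single pass over pylist.
for i, item in enumerate(pylist):
    java_array[i] = item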
@@ -59,6 +62,32 @@ def _new_java_obj(java_class, *args):
java_args = [_py2java(sc, arg) for arg in args]
return java_obj(*java_args)

@staticmethod
def _new_java_array(pylist, java_class):
"""
Create a Java array of given java_class type. Useful for
calling a method with a Scala Array from Python with Py4J.

:param pylist:
Python list to convert to a Java Array.
:param java_class:
Java class to specify the type of Array. Should be in the
form of sc._gateway.jvm.* (sc is a valid Spark Context).
:return:
Java Array of converted pylist.

Example primitive Java classes:
- basestring -> sc._gateway.jvm.java.lang.String
- int -> sc._gateway.jvm.java.lang.Integer
- float -> sc._gateway.jvm.java.lang.Double
- bool -> sc._gateway.jvm.java.lang.Boolean
"""
sc = SparkContext._active_spark_context
java_array = sc._gateway.new_array(java_class, len(pylist))
for i in xrange(len(pylist)):
java_array[i] = pylist[i]
return java_array


@inherit_doc
class JavaParams(JavaWrapper, Params):
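Finally, a short usage sketch of the new JavaWrapper._new_java_array helper, mirroring the WrapperTests above and assuming an active SparkContext named sc:

from pyspark import SparkContext
from pyspark.ml.wrapper import JavaWrapper

sc = SparkContext._active_spark_context  # assumes a SparkContext has already been started

# Build a java.lang.Double[] from a Python list of floats.
double_class = sc._gateway.jvm.java.lang.Double
java_array = JavaWrapper._new_java_array([0.1, 0.2, 0.3], double_class)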