diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index ce5725764be6..f0c618e0650b 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -403,6 +403,7 @@ def __hash__(self): "pyspark.ml.classification", "pyspark.ml.clustering", "pyspark.ml.linalg.__init__", + "pyspark.ml.pipeline", "pyspark.ml.recommendation", "pyspark.ml.regression", "pyspark.ml.tuning", diff --git a/mllib/src/main/scala/org/apache/spark/ml/api/python/PythonStage.scala b/mllib/src/main/scala/org/apache/spark/ml/api/python/PythonStage.scala new file mode 100644 index 000000000000..664e62cc846d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/api/python/PythonStage.scala @@ -0,0 +1,376 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.api.python + +import java.io.{ObjectInputStream, ObjectOutputStream} +import java.lang.reflect.Proxy + +import scala.reflect._ +import scala.reflect.ClassTag + +import org.apache.hadoop.fs.Path +import org.json4s._ + +import org.apache.spark.SparkException +import org.apache.spark.ml.{Estimator, Model, PipelineStage, Transformer} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils + +/** + * Wrapper of PipelineStage (Estimator/Model/Transformer) written in pure Python, which + * implementation is in PySpark. See pyspark.ml.util.StageWrapper + */ +private[python] trait PythonStageWrapper { + + def getUid: String + + def fit(dataset: Dataset[_]): PythonStageWrapper + + def transform(dataset: Dataset[_]): DataFrame + + def transformSchema(schema: StructType): StructType + + def getStage: Array[Byte] + + def getClassName: String + + def save(path: String): Unit + + def copy(extra: ParamMap): PythonStageWrapper + + /** + * Get the failure in PySpark, if any. + * @return the failure message if there was a failure, or `null` if there was no failure. + */ + def getLastFailure: String +} + +/** + * ML Reader for Python PipelineStages. The implementation of the reader is in Python, which is + * registered here the moment we creating a new PythonStageWrapper. + */ +private[python] object PythonStageWrapper { + private var reader: PythonStageReader = _ + + /** + * Register Python stage reader to load PySpark PipelineStages. + */ + def registerReader(r: PythonStageReader): Unit = { + reader = r + } + + /** + * Load a Python PipelineStage given its path and class name. 
+ */ + def load(path: String, clazz: String): PythonStageWrapper = { + require(reader != null, "Python reader has not been registered.") + callLoadFromPython(path, clazz) + } + + private def callLoadFromPython(path: String, clazz: String): PythonStageWrapper = { + val result = reader.load(path, clazz) + val failure = reader.getLastFailure + if (failure != null) { + throw new SparkException("An exception was raised by Python:\n" + failure) + } + result + } +} + +/** + * Reader to load a pure Python PipelineStage. Its implementation is in PySpark. + * See pyspark.ml.util.StageReader + */ +private[python] trait PythonStageReader { + + def getLastFailure: String + + def load(path: String, clazz: String): PythonStageWrapper +} + +/** + * Serializer of a pure Python PipelineStage. Its implementation is in Pyspark. + * See pyspark.ml.util.StageSerializer + */ +private[python] trait PythonStageSerializer { + + def dumps(id: String): Array[Byte] + + def loads(bytes: Array[Byte]): PythonStageWrapper + + def getLastFailure: String +} + +/** + * Helpers for PythonStageSerializer. + */ +private[python] object PythonStageSerializer { + + /** + * A serializer in Python, used to serialize PythonStageWrapper. + */ + private var serializer: PythonStageSerializer = _ + + /* + * Register a serializer from Python, should be called during initialization + */ + def register(ser: PythonStageSerializer): Unit = synchronized { + serializer = ser + } + + def serialize(wrapper: PythonStageWrapper): Array[Byte] = synchronized { + require(serializer != null, "Serializer has not been registered!") + // get the id of PythonTransformFunction in py4j + val h = Proxy.getInvocationHandler(wrapper.asInstanceOf[Proxy]) + val f = h.getClass.getDeclaredField("id") + f.setAccessible(true) + val id = f.get(h).asInstanceOf[String] + val results = serializer.dumps(id) + val failure = serializer.getLastFailure + if (failure != null) { + throw new SparkException("An exception was raised by Python:\n" + failure) + } + results + } + + def deserialize(bytes: Array[Byte]): PythonStageWrapper = synchronized { + require(serializer != null, "Serializer has not been registered!") + val wrapper = serializer.loads(bytes) + val failure = serializer.getLastFailure + if (failure != null) { + throw new SparkException("An exception was raised by Python:\n" + failure) + } + wrapper + } +} + +/** + * A proxy estimator for all PySpark estimator written in pure Python. 
+ */ +class PythonEstimator(@transient private var proxy: PythonStageWrapper) + extends Estimator[PythonModel] with PythonStageBase with MLWritable { + + override val uid: String = proxy.getUid + + private[python] override def getProxy = this.proxy + + override def fit(dataset: Dataset[_]): PythonModel = { + val modelWrapper = callFromPython(proxy.fit(dataset)) + new PythonModel(modelWrapper) + } + + override def copy(extra: ParamMap): Estimator[PythonModel] = { + this.proxy = callFromPython(proxy.copy(extra)) + this + } + + override def transformSchema(schema: StructType): StructType = { + callFromPython(proxy.transformSchema(schema)) + } + + override def write: MLWriter = new PythonEstimator.PythonEstimatorWriter(this) + + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { + val length = in.readInt() + val bytes = new Array[Byte](length) + in.readFully(bytes) + proxy = PythonStageSerializer.deserialize(bytes) + } +} + +object PythonEstimator extends MLReadable[PythonEstimator] { + + override def read: MLReader[PythonEstimator] = new PythonEstimatorReader + + override def load(path: String): PythonEstimator = super.load(path) + + private[python] class PythonEstimatorWriter(instance: PythonEstimator) + extends PythonStage.Writer[PythonEstimator](instance) + + private class PythonEstimatorReader extends PythonStage.Reader[PythonEstimator] +} + +/** + * A proxy model of all PySpark Model written in pure Python. + */ +class PythonModel(@transient private var proxy: PythonStageWrapper) + extends Model[PythonModel] with PythonStageBase with MLWritable { + + override val uid: String = proxy.getUid + + private[python] override def getProxy = this.proxy + + override def copy(extra: ParamMap): PythonModel = { + this.proxy = callFromPython(proxy.copy(extra)) + this + } + + override def transform(dataset: Dataset[_]): DataFrame = { + callFromPython(proxy.transform(dataset)) + } + + override def transformSchema(schema: StructType): StructType = { + callFromPython(proxy.transformSchema(schema)) + } + + override def write: MLWriter = new PythonModel.PythonModelWriter(this) + + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { + val length = in.readInt() + val bytes = new Array[Byte](length) + in.readFully(bytes) + proxy = PythonStageSerializer.deserialize(bytes) + } +} + +object PythonModel extends MLReadable[PythonModel] { + + override def read: MLReader[PythonModel] = new PythonModelReader + + override def load(path: String): PythonModel = super.load(path) + + private[python] class PythonModelWriter(instance: PythonModel) + extends PythonStage.Writer[PythonModel](instance) + + private class PythonModelReader extends PythonStage.Reader[PythonModel] +} + +/** + * A proxy transformer for all PySpark transformers written in pure Python. 
+ */ +class PythonTransformer(@transient private var proxy: PythonStageWrapper) + extends Transformer with PythonStageBase with MLWritable { + + override val uid: String = callFromPython(proxy.getUid) + + private[python] override def getProxy = this.proxy + + override def transformSchema(schema: StructType): StructType = { + callFromPython(proxy.transformSchema(schema)) + } + + override def transform(dataset: Dataset[_]): DataFrame = { + callFromPython(proxy.transform(dataset)) + } + + override def copy(extra: ParamMap): PythonTransformer = { + this.proxy = callFromPython(proxy.copy(extra)) + this + } + + override def write: MLWriter = new PythonTransformer.PythonTransformerWriter(this) + + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { + val length = in.readInt() + val bytes = new Array[Byte](length) + in.readFully(bytes) + proxy = PythonStageSerializer.deserialize(bytes) + } +} + +object PythonTransformer extends MLReadable[PythonTransformer] { + + override def read: MLReader[PythonTransformer] = new PythonTransformerReader + + override def load(path: String): PythonTransformer = super.load(path) + + private[python] class PythonTransformerWriter(instance: PythonTransformer) + extends PythonStage.Writer[PythonTransformer](instance) + + private class PythonTransformerReader extends PythonStage.Reader[PythonTransformer] +} + +/** + * Common functions for Python PipelineStage. + */ +trait PythonStageBase { + + private[python] def getProxy: PythonStageWrapper + + private[python] def callFromPython[R](result: R): R = { + val failure = getProxy.getLastFailure + if (failure != null) { + throw new SparkException("An exception was raised by Python:\n" + failure) + } + result + } + + /** + * Get serialized Python PipelineStage. + */ + private[python] def getPythonStage: Array[Byte] = { + callFromPython(getProxy.getStage) + } + + /** + * Get the stage's fully qualified class name in PySpark. + */ + private[python] def getPythonClassName: String = { + callFromPython(getProxy.getClassName) + } + + private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { + val bytes = PythonStageSerializer.serialize(getProxy) + out.writeInt(bytes.length) + out.write(bytes) + } +} + +private[python] object PythonStage { + /** + * Helper functions due to Py4J error of reader/serializer does not exist in the JVM. + */ + def registerReader(r: PythonStageReader): Unit = { + PythonStageWrapper.registerReader(r) + } + + def registerSerializer(ser: PythonStageSerializer): Unit = { + PythonStageSerializer.register(ser) + } + + /** + * Helper functions for Reader/Writer in Python Stages. 
+ */ + private[python] class Writer[S <: PipelineStage with PythonStageBase](instance: S) + extends MLWriter { + override protected def saveImpl(path: String): Unit = { + import org.json4s.JsonDSL._ + val extraMetadata = "pyClass" -> instance.getPythonClassName + DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) + val pyDir = new Path(path, s"pyStage-${instance.uid}").toString + instance.callFromPython(instance.getProxy.save(pyDir)) + } + } + + private[python] class Reader[S <: PipelineStage with PythonStageBase: ClassTag] + extends MLReader[S] { + private val className = classTag[S].runtimeClass.getName + override def load(path: String): S = { + implicit val format = DefaultFormats + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val pyClass = (metadata.metadata \ "pyClass").extract[String] + val pyDir = new Path(path, s"pyStage-${metadata.uid}").toString + val proxy = PythonStageWrapper.load(pyDir, pyClass) + classTag[S].runtimeClass.getConstructor(classOf[PythonStageWrapper]) + .newInstance(proxy).asInstanceOf[S] + } + } +} diff --git a/python/pyspark/context.py b/python/pyspark/context.py index aec0215b4094..d578b875744f 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -24,6 +24,7 @@ import threading from threading import RLock from tempfile import NamedTemporaryFile +from py4j.java_gateway import JavaObject from pyspark import accumulators from pyspark.accumulators import Accumulator @@ -240,6 +241,7 @@ def _ensure_initialized(cls, instance=None, gateway=None): Checks whether a SparkContext is initialized or not. Throws error if a SparkContext is already running. """ + with SparkContext._lock: if not SparkContext._gateway: SparkContext._gateway = gateway or launch_gateway() @@ -262,6 +264,29 @@ def _ensure_initialized(cls, instance=None, gateway=None): else: SparkContext._active_spark_context = instance + cls.__ensure_callback_server() + + @classmethod + def __ensure_callback_server(cls): + gw = SparkContext._gateway + + # start callback server + # getattr will fallback to JVM, so we cannot test by hasattr() + if "_callback_server" not in gw.__dict__ or gw._callback_server is None: + gw.callback_server_parameters.eager_load = True + gw.callback_server_parameters.daemonize = True + gw.callback_server_parameters.daemonize_connections = True + gw.callback_server_parameters.port = 0 + gw.start_callback_server(gw.callback_server_parameters) + cbport = gw._callback_server.server_socket.getsockname()[1] + gw._callback_server.port = cbport + # gateway with real port + gw._python_proxy_port = gw._callback_server.port + # get the GatewayServer object in JVM by ID + jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) + # update the port of CallbackClient with real port + jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port) + def __getnewargs__(self): # This method is called when attempting to pickle SparkContext, which is always an error: raise Exception( diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index ade4864e1d78..492d58d78a88 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -387,6 +387,16 @@ def copy(self, extra=None): that._paramMap = {} return self._copyValues(that, extra) + @since("2.1.0") + def transformSchema(self, schema): + """ + Transform input schema to output schema. 
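The default above is an identity mapping so existing stages keep working; a pure Python stage that adds a column would typically override it. The sketch below is illustrative only — the class name AddConstant and its column handling are made up for the example, mirroring the MockTransformer added to the tests later in this patch:

from pyspark.ml import Transformer
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
from pyspark.sql.types import DoubleType, StructField

class AddConstant(Transformer, HasInputCol, HasOutputCol):
    """Hypothetical pure-Python stage, for illustration only."""

    def transformSchema(self, schema):
        # Output schema = input schema plus one double column.
        return schema.add(StructField(self.getOutputCol(), DoubleType(), True))

    def _transform(self, dataset):
        return dataset.withColumn(self.getOutputCol(), dataset[self.getInputCol()] + 1.0)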
+ + :param schema: + :return: schema + """ + return schema + def _shouldOwn(self, param): """ Validates that the input param belongs to this Params instance. diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index a48f4bb2ad1b..8843179d7bd3 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -15,21 +15,77 @@ # limitations under the License. # -import sys - -if sys.version > '3': - basestring = str +import re from pyspark import since, keyword_only, SparkContext from pyspark.ml import Estimator, Model, Transformer -from pyspark.ml.param import Param, Params -from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable -from pyspark.ml.wrapper import JavaParams from pyspark.ml.common import inherit_doc +from pyspark.ml.param import Param, Params +from pyspark.ml.util import JavaMLReadable, JavaMLWritable, StageWrapper, _get_class +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams +from pyspark.mllib.common import _py2java, _java2py +from pyspark.serializers import CloudPickleSerializer + + +class PipelineWrapper(object): + """ + A pipeline wrapper for :py:class:`Pipeline` and :py:class:`PipelineModel` supports transferring + the array of pipeline stages between Python side and Scala side. + """ + + def _transfer_stages_to_java(self, py_stages): + """ + Transforms the parameter of Python stages to a list of Java stages. + For pure Python stage, we use its Java wrapper as proxy. + """ + + sc = SparkContext._active_spark_context + + def __transfer_stage_to_java(py_stage): + if isinstance(py_stage, JavaParams): + py_stage._transfer_params_to_java() + return py_stage._java_obj + else: + wrapper = StageWrapper(sc, py_stage) + if isinstance(py_stage, Estimator): + jstage = sc._jvm.\ + org.apache.spark.ml.api.python.PythonEstimator(wrapper) + elif isinstance(py_stage, Model): + jstage = sc._jvm.\ + org.apache.spark.ml.api.python.PythonModel(wrapper) + elif isinstance(py_stage, Transformer): + jstage = sc._jvm.\ + org.apache.spark.ml.api.python.PythonTransformer(wrapper) + else: + raise Exception( + "Unimplemented Scala wrapper for Python type %s" % type(py_stage)) + return jstage + + return [__transfer_stage_to_java(stage) for stage in py_stages] + + def _transfer_stages_from_java(self, java_sc, java_stages): + """ + Transforms the parameter Python stages from a list of Java stages. + """ + + def __transfer_stage_from_java(java_stage): + if re.match("org\.apache\.spark\.ml\.api\.python\.Python*", + java_stage.getClass().getName()): + return CloudPickleSerializer().loads(bytes(java_stage.getPythonStage())) + stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark") + # Generate a default new instance from the stage_name class. + py_stage = _get_class(stage_name)() + # Load information from java_stage to the instance. + py_stage._java_obj = java_stage + py_stage._resetUid(_java2py(java_sc, java_stage.uid())) + py_stage._transfer_params_from_java() + return py_stage + + return [__transfer_stage_from_java(stage) for stage in java_stages] @inherit_doc -class Pipeline(Estimator, MLReadable, MLWritable): +class Pipeline(PipelineWrapper, JavaEstimator, JavaMLReadable, JavaMLWritable): """ A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each of which is either an @@ -42,11 +98,69 @@ class Pipeline(Estimator, MLReadable, MLWritable): stage. 
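A hedged usage sketch of the stage transfer above: a Java-backed stage contributes its existing _java_obj, while the hypothetical AddConstant stage from the earlier sketch is wrapped in a StageWrapper and proxied by org.apache.spark.ml.api.python.PythonTransformer. The session and data are assumptions made for the example:

from pyspark.ml import Pipeline
from pyspark.ml.feature import HashingTF

# `spark` is an assumed active SparkSession.
df = spark.createDataFrame([(["a", "b", "c"], 1.0), (["c", "d", "e"], 2.0)],
                           ["words", "number"])
hashingTF = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
addConst = AddConstant().setInputCol("number").setOutputCol("number_plus_one")

pipeline = Pipeline(stages=[hashingTF, addConst])
model = pipeline.fit(df)  # fitting is now delegated to the Scala Pipeline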
If a stage is a :py:class:`Transformer`, its :py:meth:`Transformer.transform` method will be called to produce the dataset for the next stage. The fitted model from a - :py:class:`Pipeline` is a :py:class:`PipelineModel`, which + :py:class:`Pipeline` is an :py:class:`PipelineModel`, which consists of fitted models and transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as an identity transformer. + >>> from pyspark.ml.feature import HashingTF + >>> from pyspark.ml.feature import PCA + >>> df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) + >>> hashingTF = HashingTF(numFeatures=10, inputCol="words", outputCol="features") + >>> pca = PCA(k=2, inputCol="features", outputCol="pca_features") + >>> pl = Pipeline(stages=[hashingTF, pca]) + >>> model = pl.fit(df) + >>> transformed = model.transform(df) + >>> transformed.head().words == ["a", "b", "c"] + True + >>> transformed.head().features + SparseVector(10, {0: 1.0, 1: 1.0, 2: 1.0}) + >>> transformed.head().pca_features + DenseVector([-1.0, 0.5774]) + >>> import tempfile + >>> path = tempfile.mkdtemp() + >>> featurePath = path + "/feature-transformer" + >>> pl.save(featurePath) + >>> loadedPipeline = Pipeline.load(featurePath) + >>> loadedPipeline.uid == pl.uid + True + >>> len(loadedPipeline.getStages()) + 2 + >>> [loadedHT, loadedPCA] = loadedPipeline.getStages() + >>> type(loadedHT) + + >>> type(loadedPCA) + + >>> loadedHT.uid == hashingTF.uid + True + >>> param = loadedHT.getParam("numFeatures") + >>> loadedHT.getOrDefault(param) == hashingTF.getOrDefault(param) + True + >>> loadedPCA.uid == pca.uid + True + >>> loadedPCA.getK() == pca.getK() + True + >>> modelPath = path + "/feature-model" + >>> model.save(modelPath) + >>> loadedModel = PipelineModel.load(modelPath) + >>> [hashingTFinModel, pcaModel] = model.stages + >>> [loadedHTinModel, loadedPCAModel] = loadedModel.stages + >>> hashingTFinModel.uid == loadedHTinModel.uid + True + >>> hashingTFinModel.getOrDefault(param) == loadedHTinModel.getOrDefault(param) + True + >>> pcaModel.uid == loadedPCAModel.uid + True + >>> pcaModel.pc == loadedPCAModel.pc + True + >>> pcaModel.explainedVariance == loadedPCAModel.explainedVariance + True + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass + .. versionadded:: 1.3.0 """ @@ -57,12 +171,23 @@ def __init__(self, stages=None): """ __init__(self, stages=None) """ - if stages is None: - stages = [] super(Pipeline, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.Pipeline", self.uid) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) + @keyword_only + @since("1.3.0") + def setParams(self, stages=None): + """ + setParams(self, stages=None) + Sets params for Pipeline. + """ + if stages is None: + stages = [] + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + @since("1.3.0") def setStages(self, value): """ @@ -71,7 +196,9 @@ def setStages(self, value): :param value: a list of transformers or estimators :return: the pipeline instance """ - return self._set(stages=value) + self._paramMap[self.stages] = value + self._java_stages = self._transfer_stages_to_java(value) + return self @since("1.3.0") def getStages(self): @@ -81,179 +208,81 @@ def getStages(self): if self.stages in self._paramMap: return self._paramMap[self.stages] - @keyword_only - @since("1.3.0") - def setParams(self, stages=None): - """ - setParams(self, stages=None) - Sets params for Pipeline. 
- """ - if stages is None: - stages = [] - kwargs = self.setParams._input_kwargs - return self._set(**kwargs) - - def _fit(self, dataset): - stages = self.getStages() - for stage in stages: - if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)): - raise TypeError( - "Cannot recognize a pipeline stage of type %s." % type(stage)) - indexOfLastEstimator = -1 - for i, stage in enumerate(stages): - if isinstance(stage, Estimator): - indexOfLastEstimator = i - transformers = [] - for i, stage in enumerate(stages): - if i <= indexOfLastEstimator: - if isinstance(stage, Transformer): - transformers.append(stage) - dataset = stage.transform(dataset) - else: # must be an Estimator - model = stage.fit(dataset) - transformers.append(model) - if i < indexOfLastEstimator: - dataset = model.transform(dataset) - else: - transformers.append(stage) - return PipelineModel(transformers) - - @since("1.4.0") - def copy(self, extra=None): + def _transfer_params_to_java(self): """ - Creates a copy of this instance. - - :param extra: extra parameters - :returns: new instance + Transforms the parameter stages to Java stages. """ - if extra is None: - extra = dict() - that = Params.copy(self, extra) - stages = [stage.copy(extra) for stage in that.getStages()] - return that.setStages(stages) - - @since("2.0.0") - def write(self): - """Returns an MLWriter instance for this ML instance.""" - return JavaMLWriter(self) + paramMap = self.extractParamMap() + if self.stages not in paramMap: + return + param = self.stages + value = paramMap[param] + + sc = SparkContext._active_spark_context + param = self._resolveParam(param) + java_param = self._java_obj.getParam(param.name) + gateway = SparkContext._gateway + jvm = SparkContext._jvm + stageArray = gateway.new_array(jvm.org.apache.spark.ml.PipelineStage, len(value)) - @since("2.0.0") - def save(self, path): - """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" - self.write().save(path) + for idx, java_stage in enumerate(self._transfer_stages_to_java(self.getStages())): + stageArray[idx] = java_stage - @classmethod - @since("2.0.0") - def read(cls): - """Returns an MLReader instance for this class.""" - return JavaMLReader(cls) + java_value = _py2java(sc, stageArray) + self._java_obj.set(java_param.w(java_value)) - @classmethod - def _from_java(cls, java_stage): - """ - Given a Java Pipeline, create and return a Python wrapper of it. - Used for ML persistence. + def _transfer_params_from_java(self): """ - # Create a new instance of this stage. - py_stage = cls() - # Load information from java_stage to the instance. - py_stages = [JavaParams._from_java(s) for s in java_stage.getStages()] - py_stage.setStages(py_stages) - py_stage._resetUid(java_stage.uid()) - return py_stage - - def _to_java(self): + Transforms the parameter stages from the companion Java object. """ - Transfer this instance to a Java Pipeline. Used for ML persistence. - - :return: Java object equivalent to this instance. 
- """ - - gateway = SparkContext._gateway - cls = SparkContext._jvm.org.apache.spark.ml.PipelineStage - java_stages = gateway.new_array(cls, len(self.getStages())) - for idx, stage in enumerate(self.getStages()): - java_stages[idx] = stage._to_java() + sc = SparkContext._active_spark_context + assert self._java_obj.hasParam(self.stages.name) + java_param = self._java_obj.getParam(self.stages.name) + if self._java_obj.isDefined(java_param): + java_stages = _java2py(sc, self._java_obj.getOrDefault(java_param)) + self._paramMap[self.stages] = self._transfer_stages_from_java(sc, java_stages) + else: + self._paramMap[self.stages] = [] - _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.Pipeline", self.uid) - _java_obj.setStages(java_stages) - - return _java_obj + def _create_model(self, java_model): + return PipelineModel(java_model) @inherit_doc -class PipelineModel(Model, MLReadable, MLWritable): +class PipelineModel(PipelineWrapper, JavaModel, JavaMLReadable, JavaMLWritable): """ Represents a compiled pipeline with transformers and fitted models. .. versionadded:: 1.3.0 """ - def __init__(self, stages): - super(PipelineModel, self).__init__() - self.stages = stages - - def _transform(self, dataset): - for t in self.stages: - dataset = t.transform(dataset) - return dataset - - @since("1.4.0") - def copy(self, extra=None): - """ - Creates a copy of this instance. - - :param extra: extra parameters - :returns: new instance - """ - if extra is None: - extra = dict() - stages = [stage.copy(extra) for stage in self.stages] - return PipelineModel(stages) - + @property @since("2.0.0") - def write(self): - """Returns an MLWriter instance for this ML instance.""" - return JavaMLWriter(self) - - @since("2.0.0") - def save(self, path): - """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" - self.write().save(path) - - @classmethod - @since("2.0.0") - def read(cls): - """Returns an MLReader instance for this class.""" - return JavaMLReader(cls) - - @classmethod - def _from_java(cls, java_stage): - """ - Given a Java PipelineModel, create and return a Python wrapper of it. - Used for ML persistence. + def stages(self): """ - # Load information from java_stage to the instance. - py_stages = [JavaParams._from_java(s) for s in java_stage.stages()] - # Create a new instance of this stage. - py_stage = cls(py_stages) - py_stage._resetUid(java_stage.uid()) - return py_stage - - def _to_java(self): - """ - Transfer this instance to a Java PipelineModel. Used for ML persistence. - - :return: Java object equivalent to this instance. + Returns stages of the pipeline model. 
""" - - gateway = SparkContext._gateway - cls = SparkContext._jvm.org.apache.spark.ml.Transformer - java_stages = gateway.new_array(cls, len(self.stages)) - for idx, stage in enumerate(self.stages): - java_stages[idx] = stage._to_java() - - _java_obj =\ - JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages) - - return _java_obj + sc = SparkContext._active_spark_context + java_stages = self._call_java("stages") + py_stages = self._transfer_stages_from_java(sc, java_stages) + return py_stages + + +if __name__ == "__main__": + import doctest + import pyspark.ml + import pyspark.ml.feature + from pyspark.sql import SQLContext + globs = pyspark.ml.__dict__.copy() + globs_feature = pyspark.ml.feature.__dict__.copy() + globs.update(globs_feature) + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + sc = SparkContext("local[2]", "ml.pipeline tests") + sqlContext = SQLContext(sc) + globs['sc'] = sc + globs['sqlContext'] = sqlContext + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + if failure_count: + exit(-1) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 981ed9dda042..dd92ac288f11 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -55,7 +55,7 @@ from pyspark.ml.linalg import Vector, SparseVector, DenseVector, VectorUDT,\ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT, _convert_to_vector from pyspark.ml.param import Param, Params, TypeConverters -from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed +from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasOutputCol, HasSeed from pyspark.ml.recommendation import ALS from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor, \ GeneralizedLinearRegression @@ -65,6 +65,7 @@ from pyspark.serializers import PickleSerializer from pyspark.sql import DataFrame, Row, SparkSession from pyspark.sql.functions import rand +from pyspark.sql.types import StructField, DoubleType from pyspark.sql.utils import IllegalArgumentException from pyspark.storagelevel import * from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase @@ -93,12 +94,6 @@ def tearDownClass(cls): cls.spark.stop() -class MockDataset(DataFrame): - - def __init__(self): - self.index = 0 - - class HasFake(Params): def __init__(self): @@ -109,35 +104,6 @@ def getFake(self): return self.getOrDefault(self.fake) -class MockTransformer(Transformer, HasFake): - - def __init__(self): - super(MockTransformer, self).__init__() - self.dataset_index = None - - def _transform(self, dataset): - self.dataset_index = dataset.index - dataset.index += 1 - return dataset - - -class MockEstimator(Estimator, HasFake): - - def __init__(self): - super(MockEstimator, self).__init__() - self.dataset_index = None - - def _fit(self, dataset): - self.dataset_index = dataset.index - model = MockModel() - self._copyValues(model) - return model - - -class MockModel(MockTransformer, Model, HasFake): - pass - - class ParamTypeConversionTests(PySparkTestCase): """ Test that param type conversion happens. 
@@ -204,31 +170,71 @@ def test_bool(self): self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept="false")) -class PipelineTests(PySparkTestCase): +class MockTransformer(Transformer, HasInputCol, HasOutputCol): + factor = Param(Params._dummy(), "factor", "factor", typeConverter=TypeConverters.toFloat) + + def __init__(self): + super(MockTransformer, self).__init__() + self._setDefault(factor=1) + + def transformSchema(self, schema): + outSchema = StructField(self.getOutputCol(), DoubleType(), True) + return schema.add(outSchema) + + def _transform(self, dataset): + inc = self.getInputCol() + ouc = self.getOutputCol() + factor = self.getOrDefault(self.factor) + return dataset.withColumn(ouc, dataset[inc] + factor) + + +class MockModel(MockTransformer): + def __init__(self): + super(MockModel, self).__init__() + + +class MockEstimator(Estimator, HasInputCol, HasOutputCol): + + def __init__(self): + super(MockEstimator, self).__init__() + + def transformSchema(self, schema): + outSchema = StructField(self.getOutputCol(), DoubleType(), True) + return schema.add(outSchema) + + def _fit(self, dataset): + cnt = dataset.count() + model = MockModel()._set(factor=cnt)._resetUid(self.uid) + return self._copyValues(model) + + +class PipelineTests(SparkSessionTestCase): def test_pipeline(self): - dataset = MockDataset() + data = self.spark.createDataFrame([(1,), (2,), (3,), (4,)], ["number"]) + + transformer0 = MockTransformer() + transformer0.setInputCol("number").setOutputCol("result0") estimator0 = MockEstimator() + estimator0.setInputCol(transformer0.getOutputCol()).setOutputCol("result1") transformer1 = MockTransformer() - estimator2 = MockEstimator() - transformer3 = MockTransformer() - pipeline = Pipeline(stages=[estimator0, transformer1, estimator2, transformer3]) - pipeline_model = pipeline.fit(dataset, {estimator0.fake: 0, transformer1.fake: 1}) - model0, transformer1, model2, transformer3 = pipeline_model.stages - self.assertEqual(0, model0.dataset_index) - self.assertEqual(0, model0.getFake()) - self.assertEqual(1, transformer1.dataset_index) - self.assertEqual(1, transformer1.getFake()) - self.assertEqual(2, dataset.index) - self.assertIsNone(model2.dataset_index, "The last model shouldn't be called in fit.") - self.assertIsNone(transformer3.dataset_index, - "The last transformer shouldn't be called in fit.") - dataset = pipeline_model.transform(dataset) - self.assertEqual(2, model0.dataset_index) - self.assertEqual(3, transformer1.dataset_index) - self.assertEqual(4, model2.dataset_index) - self.assertEqual(5, transformer3.dataset_index) - self.assertEqual(6, dataset.index) + transformer1.setInputCol(estimator0.getOutputCol()).setOutputCol("result2")._set(factor=2) + + pipeline = Pipeline(stages=[transformer0, estimator0, transformer1]) + + model = pipeline.fit(data) + + self.assertEqual(len(model.stages), 3) + self.assertIsInstance(model.stages[0], MockTransformer) + self.assertEqual(model.stages[0].uid, transformer0.uid) + self.assertIsInstance(model.stages[1], MockModel) + self.assertEqual(model.stages[1].uid, estimator0.uid) + self.assertIsInstance(model.stages[2], MockTransformer) + self.assertEqual(model.stages[2].uid, transformer1.uid) + + result = model.transform(data).select(transformer1.getOutputCol()).collect() + self.assertListEqual( + result, [Row(result2=8), Row(result2=9), Row(result2=10), Row(result2=11)]) class TestParams(HasMaxIter, HasInputCol, HasSeed): @@ -758,8 +764,16 @@ def _compare_params(self, m1, m2, param): paramValue2 = 
m2.getOrDefault(m2.getParam(param.name)) if isinstance(paramValue1, Params): self._compare_pipelines(paramValue1, paramValue2) - else: - self.assertEqual(paramValue1, paramValue2) # for general types param + elif isinstance(paramValue1, list): + if not paramValue1: + self.assertEqual(paramValue2, []) + elif isinstance(paramValue1[0], Params): + for p1, p2 in zip(paramValue1, paramValue2): + self._compare_pipelines(p1, p2) + else: + self.assertListEqual(paramValue1, paramValue2) + else: # for general types param + self.assertEqual(paramValue1, paramValue2) # Assert parents are equal self.assertEqual(param.parent, m2.getParam(param.name).parent) else: diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 4a31a298096f..a411d9aa42f5 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -22,8 +22,14 @@ basestring = str unicode = str -from pyspark import SparkContext, since +import traceback + +from pyspark import SparkContext from pyspark.ml.common import inherit_doc +from pyspark.serializers import CloudPickleSerializer +from pyspark.sql.dataframe import DataFrame +from pyspark.sql.types import _parse_datatype_json_string +from pyspark.sql import SQLContext def _jvm(): @@ -38,6 +44,18 @@ def _jvm(): raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?") +def _get_class(clazz): + """ + Loads Python class from its name. + """ + parts = clazz.split('.') + module = ".".join(parts[:-1]) + m = __import__(module) + for comp in parts[1:]: + m = getattr(m, comp) + return m + + class Identifiable(object): """ Object with a unique ID. @@ -238,3 +256,159 @@ class JavaMLReadable(MLReadable): def read(cls): """Returns an MLReader instance for this class.""" return JavaMLReader(cls) + + +class StageWrapper(object): + """ + This class wraps a pure Python stage, allowing it to be called from Java via Py4J's + callback server, making it as-like a Java side PipelineStage. + """ + + def __init__(self, sc, stage): + self.sc = sc + self.sql_ctx = SQLContext.getOrCreate(self.sc) + self.stage = stage + self.failure = None + reader = StageReader(self.sc) + self.sc._gateway.jvm.\ + org.apache.spark.ml.api.python.PythonStage.registerReader(reader) + + def getUid(self): + self.failure = None + try: + return self.stage.uid + except: + self.failure = traceback.format_exc() + + def copy(self, extra): + self.failure = None + try: + self.stage = self.stage.copy(extra) + return self + except: + self.failure = traceback.format_exc() + + def transformSchema(self, jschema): + """ + Transform Java schema with transformSchema in pure Python stages. + """ + self.failure = None + try: + schema = _parse_datatype_json_string(jschema.json()) + converted = self.stage.transformSchema(schema) + return _jvm().org.apache.spark.sql.types.StructType.fromJson(converted.json()) + except: + self.failure = traceback.format_exc() + + def getStage(self): + self.failure = None + try: + return bytearray(CloudPickleSerializer().dumps(self.stage)) + except: + self.failure = traceback.format_exc() + + def getClassName(self): + self.failure = None + try: + cls = self.stage.__class__ + return cls.__module__ + "." 
+ cls.__name__ + except: + self.failure = traceback.format_exc() + + def fit(self, jdf): + self.failure = None + try: + df = DataFrame(jdf, self.sql_ctx) if jdf else None + m = self.stage.fit(df) + if m: + return StageWrapper(self.sc, m) + except: + self.failure = traceback.format_exc() + + def transform(self, jdf): + self.failure = None + try: + df = DataFrame(jdf, self.sql_ctx) if jdf else None + r = self.stage.transform(df) + if r: + return r._jdf + except: + self.failure = traceback.format_exc() + + def getLastFailure(self): + return self.failure + + def save(self, path): + self.failure = None + try: + self.stage.save(path) + except: + self.failure = traceback.format_exc() + + def __repr__(self): + return "StageWrapper(%s)" % self.stage + + class Java: + implements = ['org.apache.spark.ml.api.python.PythonStageWrapper'] + + +class StageReader(object): + """ + Reader to load Python stages. + """ + def __init__(self, sc): + self.failure = None + self.sc = sc + + def getLastFailure(self): + return self.failure + + def load(self, path, clazz): + self.failure = None + try: + cls = _get_class(clazz) + transformer = cls.load(path) + return StageWrapper(self.sc, transformer) + except: + self.failure = traceback.format_exc() + + class Java: + implements = ['org.apache.spark.ml.api.python.PythonStageReader'] + + +class StageSerializer(object): + """ + This class implements a serializer for PythonStageWrapper Java objects. + """ + def __init__(self, sc, serializer): + self.sc = sc + self.serializer = serializer + self.gateway = self.sc._gateway + self.gateway.jvm\ + .org.apache.spark.ml.api.python.PythonPipelineStage.registerSerializer(self) + self.failure = None + + def dumps(self, id): + self.failure = None + try: + wrapper = self.gateway.gateway_property.pool[id] + return bytearray(self.serializer.dumps(wrapper.stage)) + except: + self.failure = traceback.format_exc() + + def loads(self, data): + self.failure = None + try: + stage = self.serializer.loads(bytes(data)) + return StageWrapper(self.sc, stage) + except: + self.failure = traceback.format_exc() + + def getLastFailure(self): + return self.failure + + def __repr__(self): + return "StageSerializer(%s)" % self.serializer + + class Java: + implements = ['org.apache.spark.ml.api.python.PythonStageSerializer'] diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 25c44b7533c7..d39d598d9b91 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -21,7 +21,7 @@ from pyspark.sql import DataFrame from pyspark.ml import Estimator, Transformer, Model from pyspark.ml.param import Params -from pyspark.ml.util import _jvm +from pyspark.ml.util import _jvm, _get_class from pyspark.ml.common import inherit_doc, _java2py, _py2java @@ -154,19 +154,9 @@ def _from_java(java_stage): Meta-algorithms such as Pipeline should override this method as a classmethod. """ - def __get_class(clazz): - """ - Loads Python class from its name. - """ - parts = clazz.split('.') - module = ".".join(parts[:-1]) - m = __import__(module) - for comp in parts[1:]: - m = getattr(m, comp) - return m stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark") # Generate a default new instance from the stage_name class. - py_type = __get_class(stage_name) + py_type = _get_class(stage_name) if issubclass(py_type, JavaParams): # Load information from java_stage to the instance. py_stage = py_type()
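Taken together with the SparkContext changes earlier in this patch, the JVM can now call back into the StageWrapper / StageReader / StageSerializer objects above. A minimal sketch, assuming a local session, of checking that the callback server came up on a real (ephemeral) port rather than the requested port 0:

from pyspark import SparkContext

sc = SparkContext.getOrCreate()
gw = SparkContext._gateway
assert gw._callback_server is not None
# Port 0 was requested, so the OS assigned a free port; the JVM-side
# CallbackClient was reset to this same port.
print(gw._callback_server.server_socket.getsockname()[1])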