65 commits
5eb0245
add fake transformer with pure Python
yinxusen Jun 7, 2016
06a6545
add an action in pure python transformer
yinxusen Jun 8, 2016
70c0945
Merge branch 'master' into SPARK-15574
yinxusen Jun 9, 2016
73cbcc7
Merge branch 'master' into SPARK-15574
yinxusen Jun 10, 2016
4418ae6
add transformer wrapper
yinxusen Jun 10, 2016
72787fe
update python transformer
yinxusen Jun 13, 2016
f5094ad
add python transformer wrapper
yinxusen Jun 13, 2016
9e2632e
add active spark context
yinxusen Jun 13, 2016
8cff5bc
split uid
yinxusen Jun 13, 2016
6bde27f
fix more
yinxusen Jun 13, 2016
5fdeac8
start callback server
yinxusen Jun 13, 2016
32cf6e4
fix order
yinxusen Jun 13, 2016
2079a03
fix order
yinxusen Jun 14, 2016
e27da47
Merge branch 'master' into SPARK-15574
yinxusen Jun 14, 2016
d499d0e
add pipeline as prototype
yinxusen Jun 15, 2016
371485a
add support for pure python transformer
yinxusen Jun 15, 2016
0e026e5
fix error of classmethod
yinxusen Jun 15, 2016
3655ab3
add pure_pipeline as test
yinxusen Jun 15, 2016
7c7f684
add debug info
yinxusen Jun 15, 2016
96f7361
add transformSchema
yinxusen Jun 15, 2016
954bcb8
add docstring
yinxusen Jun 15, 2016
c41d4dc
convert json to string
yinxusen Jun 15, 2016
c289be4
change API
yinxusen Jun 15, 2016
72facb2
add another pure python transformer
yinxusen Jun 15, 2016
57a9bd1
add to pipeline
yinxusen Jun 15, 2016
3536503
fix fromJava
yinxusen Jun 15, 2016
fe42e4f
fix bugs
yinxusen Jun 15, 2016
9f77900
add getTransformer
yinxusen Jun 16, 2016
0047cf2
add method into java wrapper
yinxusen Jun 16, 2016
752570d
add ser/de for transformer
yinxusen Jun 16, 2016
7a1b564
get failure from get transformer
yinxusen Jun 16, 2016
1c05c66
for debug purpose
yinxusen Jun 16, 2016
3346cae
change pickle to cloud pickle
yinxusen Jun 16, 2016
ff4a1b1
use cloud pickle
yinxusen Jun 16, 2016
67e384e
add save/load for python wrapper
yinxusen Jun 17, 2016
710cbcd
fix class name
yinxusen Jun 17, 2016
e63d700
change sc for sqlsc
yinxusen Jun 20, 2016
5db5466
add helper function
yinxusen Jun 20, 2016
6c06125
fix python side
yinxusen Jun 20, 2016
9ce5fbd
change MLRead/Write to Java version
yinxusen Jun 20, 2016
7a99f88
catch error from python side
yinxusen Jun 20, 2016
494d4b8
refine code in PythonPipelineStage
yinxusen Jun 20, 2016
d2c1113
fix python util
yinxusen Jun 20, 2016
db544c1
fix error
yinxusen Jun 20, 2016
dd95649
simplify the transform
yinxusen Jun 20, 2016
191f5ce
add descriptors
yinxusen Jun 20, 2016
62256ea
change limitations
yinxusen Jun 20, 2016
7eb3d0f
add new test
yinxusen Jun 21, 2016
aab8578
delete helper files
yinxusen Jun 21, 2016
3ab02a5
merge with master
yinxusen Jun 21, 2016
7c38eef
fix errors
yinxusen Jun 21, 2016
fbef5e8
add wrapper for estimator and model
yinxusen Jun 22, 2016
451fb4b
remove unnecessary get_class
yinxusen Jun 22, 2016
2b9697d
add getStage for estimator and model
yinxusen Jun 22, 2016
22966a3
move transformSchema to Params
yinxusen Jun 23, 2016
c16d711
remote PurePythonXXX
yinxusen Jun 23, 2016
4cd1034
add new mock transform estimator
yinxusen Jun 23, 2016
0db67ec
reset uid for model
yinxusen Jun 23, 2016
001a605
refine test result
yinxusen Jun 23, 2016
d377ac6
rename helper classes
yinxusen Jun 23, 2016
f4ec73d
add save/load for estimator and model
yinxusen Jun 23, 2016
8d23bac
Merge branch 'master' into SPARK-15574
yinxusen Jun 23, 2016
1ccdea5
fix error
yinxusen Jun 23, 2016
4db921e
fix error object name
yinxusen Jun 23, 2016
b8ddcdb
add uid as part of saving path
yinxusen Jun 23, 2016
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -403,6 +403,7 @@ def __hash__(self):
"pyspark.ml.classification",
"pyspark.ml.clustering",
"pyspark.ml.linalg.__init__",
"pyspark.ml.pipeline",
"pyspark.ml.recommendation",
"pyspark.ml.regression",
"pyspark.ml.tuning",
376 changes: 376 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/api/python/PythonStage.scala
@@ -0,0 +1,376 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.api.python

import java.io.{ObjectInputStream, ObjectOutputStream}
import java.lang.reflect.Proxy

import scala.reflect._
import scala.reflect.ClassTag

import org.apache.hadoop.fs.Path
import org.json4s._

import org.apache.spark.SparkException
import org.apache.spark.ml.{Estimator, Model, PipelineStage, Transformer}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

/**
 * Wrapper of a PipelineStage (Estimator/Model/Transformer) written in pure Python, whose
 * implementation lives in PySpark. See pyspark.ml.util.StageWrapper.
 */
private[python] trait PythonStageWrapper {

def getUid: String

def fit(dataset: Dataset[_]): PythonStageWrapper

def transform(dataset: Dataset[_]): DataFrame

def transformSchema(schema: StructType): StructType

def getStage: Array[Byte]

def getClassName: String

def save(path: String): Unit

def copy(extra: ParamMap): PythonStageWrapper

/**
* Get the failure in PySpark, if any.
* @return the failure message if there was a failure, or `null` if there was no failure.
*/
def getLastFailure: String
}
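
/*
 * A hedged sketch of the failure-propagation contract (illustrative only; the wrapper
 * below is assumed to be a Py4J proxy obtained from the Python side): after every call
 * into the wrapper, getLastFailure must be checked, and a null result means the call
 * succeeded. The shared helper for this pattern is PythonStageBase.callFromPython
 * further down in this file.
 */
private[python] object PythonStageFailureExample {
  def transformChecked(wrapper: PythonStageWrapper, dataset: Dataset[_]): DataFrame = {
    val result = wrapper.transform(dataset)
    val failure = wrapper.getLastFailure
    if (failure != null) {
      throw new SparkException("An exception was raised by Python:\n" + failure)
    }
    result
  }
}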

/**
 * ML reader for Python PipelineStages. The reader's implementation is in Python and is
 * registered here when a new PythonStageWrapper is created.
 */
private[python] object PythonStageWrapper {
private var reader: PythonStageReader = _

/**
* Register Python stage reader to load PySpark PipelineStages.
*/
def registerReader(r: PythonStageReader): Unit = {
reader = r
}

/**
* Load a Python PipelineStage given its path and class name.
*/
def load(path: String, clazz: String): PythonStageWrapper = {
require(reader != null, "Python reader has not been registered.")
callLoadFromPython(path, clazz)
}

private def callLoadFromPython(path: String, clazz: String): PythonStageWrapper = {
val result = reader.load(path, clazz)
val failure = reader.getLastFailure
if (failure != null) {
throw new SparkException("An exception was raised by Python:\n" + failure)
}
result
}
}
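
/*
 * A minimal sketch of the load flow (illustrative only; the reader instance, path, and
 * class name below are assumptions, not part of this file). The Python side registers
 * its reader once through Py4J, after which stages saved by PySpark can be loaded by
 * path and fully qualified PySpark class name.
 */
private[python] object PythonStageWrapperLoadExample {
  def loadStage(reader: PythonStageReader): PythonStageWrapper = {
    PythonStageWrapper.registerReader(reader)
    PythonStageWrapper.load("/tmp/py-stage", "pyspark.ml.tests.MockTransformer")
  }
}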

/**
* Reader to load a pure Python PipelineStage. Its implementation is in PySpark.
* See pyspark.ml.util.StageReader
*/
private[python] trait PythonStageReader {

def getLastFailure: String

def load(path: String, clazz: String): PythonStageWrapper
}

/**
 * Serializer of a pure Python PipelineStage. Its implementation is in PySpark.
 * See pyspark.ml.util.StageSerializer.
 */
private[python] trait PythonStageSerializer {

def dumps(id: String): Array[Byte]

def loads(bytes: Array[Byte]): PythonStageWrapper

def getLastFailure: String
}

/**
* Helpers for PythonStageSerializer.
*/
private[python] object PythonStageSerializer {

/**
* A serializer in Python, used to serialize PythonStageWrapper.
*/
private var serializer: PythonStageSerializer = _

/**
 * Register a serializer from Python; this should be called during initialization.
 */
def register(ser: PythonStageSerializer): Unit = synchronized {
serializer = ser
}

def serialize(wrapper: PythonStageWrapper): Array[Byte] = synchronized {
require(serializer != null, "Serializer has not been registered!")
// Get the Py4J id of the PythonStageWrapper proxy.
val h = Proxy.getInvocationHandler(wrapper.asInstanceOf[Proxy])
val f = h.getClass.getDeclaredField("id")
f.setAccessible(true)
val id = f.get(h).asInstanceOf[String]
val results = serializer.dumps(id)
val failure = serializer.getLastFailure
if (failure != null) {
throw new SparkException("An exception was raised by Python:\n" + failure)
}
results
}

def deserialize(bytes: Array[Byte]): PythonStageWrapper = synchronized {
require(serializer != null, "Serializer has not been registered!")
val wrapper = serializer.loads(bytes)
val failure = serializer.getLastFailure
if (failure != null) {
throw new SparkException("An exception was raised by Python:\n" + failure)
}
wrapper
}
}
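
/*
 * A sketch of the serialize/deserialize round trip (illustrative only; the wrapper is
 * assumed to be a Py4J proxy). serialize extracts the proxy's Py4J id and asks Python
 * to cloudpickle the underlying stage; deserialize rebuilds a wrapper from those bytes.
 */
private[python] object PythonStageSerializerExample {
  def roundTrip(wrapper: PythonStageWrapper): PythonStageWrapper = {
    val bytes: Array[Byte] = PythonStageSerializer.serialize(wrapper)
    PythonStageSerializer.deserialize(bytes)
  }
}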

/**
 * A proxy estimator for all PySpark estimators written in pure Python.
 */
class PythonEstimator(@transient private var proxy: PythonStageWrapper)
extends Estimator[PythonModel] with PythonStageBase with MLWritable {

override val uid: String = proxy.getUid

private[python] override def getProxy = this.proxy

override def fit(dataset: Dataset[_]): PythonModel = {
val modelWrapper = callFromPython(proxy.fit(dataset))
new PythonModel(modelWrapper)
}

override def copy(extra: ParamMap): Estimator[PythonModel] = {
this.proxy = callFromPython(proxy.copy(extra))
this
}

override def transformSchema(schema: StructType): StructType = {
callFromPython(proxy.transformSchema(schema))
}

override def write: MLWriter = new PythonEstimator.PythonEstimatorWriter(this)

private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
val length = in.readInt()
val bytes = new Array[Byte](length)
in.readFully(bytes)
proxy = PythonStageSerializer.deserialize(bytes)
}
}

object PythonEstimator extends MLReadable[PythonEstimator] {

override def read: MLReader[PythonEstimator] = new PythonEstimatorReader

override def load(path: String): PythonEstimator = super.load(path)

private[python] class PythonEstimatorWriter(instance: PythonEstimator)
extends PythonStage.Writer[PythonEstimator](instance)

private class PythonEstimatorReader extends PythonStage.Reader[PythonEstimator]
}
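
/*
 * A usage sketch (illustrative only; the wrapper and datasets are assumptions): fit
 * delegates to the Python-side estimator and yields a PythonModel whose transform also
 * runs in Python.
 */
private[python] object PythonEstimatorExample {
  def fitAndPredict(
      wrapper: PythonStageWrapper,
      training: Dataset[_],
      test: Dataset[_]): DataFrame = {
    val estimator = new PythonEstimator(wrapper)
    val model: PythonModel = estimator.fit(training)
    model.transform(test)
  }
}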

/**
 * A proxy model for all PySpark models written in pure Python.
 */
class PythonModel(@transient private var proxy: PythonStageWrapper)
extends Model[PythonModel] with PythonStageBase with MLWritable {

override val uid: String = proxy.getUid

private[python] override def getProxy = this.proxy

override def copy(extra: ParamMap): PythonModel = {
this.proxy = callFromPython(proxy.copy(extra))
this
}

override def transform(dataset: Dataset[_]): DataFrame = {
callFromPython(proxy.transform(dataset))
}

override def transformSchema(schema: StructType): StructType = {
callFromPython(proxy.transformSchema(schema))
}

override def write: MLWriter = new PythonModel.PythonModelWriter(this)

private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
val length = in.readInt()
val bytes = new Array[Byte](length)
in.readFully(bytes)
proxy = PythonStageSerializer.deserialize(bytes)
}
}

object PythonModel extends MLReadable[PythonModel] {

override def read: MLReader[PythonModel] = new PythonModelReader

override def load(path: String): PythonModel = super.load(path)

private[python] class PythonModelWriter(instance: PythonModel)
extends PythonStage.Writer[PythonModel](instance)

private class PythonModelReader extends PythonStage.Reader[PythonModel]
}

/**
* A proxy transformer for all PySpark transformers written in pure Python.
*/
class PythonTransformer(@transient private var proxy: PythonStageWrapper)
extends Transformer with PythonStageBase with MLWritable {

override val uid: String = callFromPython(proxy.getUid)

private[python] override def getProxy = this.proxy

override def transformSchema(schema: StructType): StructType = {
callFromPython(proxy.transformSchema(schema))
}

override def transform(dataset: Dataset[_]): DataFrame = {
callFromPython(proxy.transform(dataset))
}

override def copy(extra: ParamMap): PythonTransformer = {
this.proxy = callFromPython(proxy.copy(extra))
this
}

override def write: MLWriter = new PythonTransformer.PythonTransformerWriter(this)

private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
val length = in.readInt()
val bytes = new Array[Byte](length)
in.readFully(bytes)
proxy = PythonStageSerializer.deserialize(bytes)
}
}

object PythonTransformer extends MLReadable[PythonTransformer] {

override def read: MLReader[PythonTransformer] = new PythonTransformerReader

override def load(path: String): PythonTransformer = super.load(path)

private[python] class PythonTransformerWriter(instance: PythonTransformer)
extends PythonStage.Writer[PythonTransformer](instance)

private class PythonTransformerReader extends PythonStage.Reader[PythonTransformer]
}
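
/*
 * A persistence sketch (illustrative only; the path is an assumption). write saves JVM
 * metadata plus a pyStage-<uid> subdirectory handled by the Python side; load restores
 * the transformer through the registered Python reader.
 */
private[python] object PythonTransformerPersistenceExample {
  def saveAndReload(transformer: PythonTransformer): PythonTransformer = {
    val path = "/tmp/python-transformer"
    transformer.write.overwrite().save(path)
    PythonTransformer.load(path)
  }
}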

/**
 * Common functions for Python PipelineStages.
 */
trait PythonStageBase {

private[python] def getProxy: PythonStageWrapper

private[python] def callFromPython[R](result: R): R = {
val failure = getProxy.getLastFailure
if (failure != null) {
throw new SparkException("An exception was raised by Python:\n" + failure)
}
result
}

/**
 * Get the serialized Python PipelineStage.
 */
private[python] def getPythonStage: Array[Byte] = {
callFromPython(getProxy.getStage)
}

/**
* Get the stage's fully qualified class name in PySpark.
*/
private[python] def getPythonClassName: String = {
callFromPython(getProxy.getClassName)
}

private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
val bytes = PythonStageSerializer.serialize(getProxy)
out.writeInt(bytes.length)
out.write(bytes)
}
}
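
/*
 * A sketch of the Java-serialization path above (illustrative only): writeObject
 * delegates to PythonStageSerializer, so a stage survives a plain
 * ObjectOutputStream/ObjectInputStream round trip, for example inside task closures.
 */
private[python] object PythonStageJavaSerializationExample {
  import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

  def roundTrip(stage: PythonTransformer): PythonTransformer = {
    val buffer = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buffer)
    out.writeObject(stage) // invokes PythonStageBase.writeObject
    out.close()
    val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
    in.readObject().asInstanceOf[PythonTransformer] // invokes readObject, restoring the proxy
  }
}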

private[python] object PythonStage {
/**
 * Registration helpers, needed to work around a Py4J error where the reader/serializer
 * cannot be found in the JVM when Python registers them on the target objects directly.
 */
def registerReader(r: PythonStageReader): Unit = {
PythonStageWrapper.registerReader(r)
}

def registerSerializer(ser: PythonStageSerializer): Unit = {
PythonStageSerializer.register(ser)
}

/**
 * Shared Writer/Reader implementations for Python stages.
 */
private[python] class Writer[S <: PipelineStage with PythonStageBase](instance: S)
extends MLWriter {
override protected def saveImpl(path: String): Unit = {
import org.json4s.JsonDSL._
val extraMetadata = "pyClass" -> instance.getPythonClassName
DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
val pyDir = new Path(path, s"pyStage-${instance.uid}").toString
instance.callFromPython(instance.getProxy.save(pyDir))
}
}

private[python] class Reader[S <: PipelineStage with PythonStageBase: ClassTag]
extends MLReader[S] {
private val className = classTag[S].runtimeClass.getName
override def load(path: String): S = {
implicit val format = DefaultFormats
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val pyClass = (metadata.metadata \ "pyClass").extract[String]
val pyDir = new Path(path, s"pyStage-${metadata.uid}").toString
val proxy = PythonStageWrapper.load(pyDir, pyClass)
classTag[S].runtimeClass.getConstructor(classOf[PythonStageWrapper])
.newInstance(proxy).asInstanceOf[S]
}
}
}