[SPARK-12042] Python API for mllib.stat.test.StreamingTest #11374
Changes from all commits: 6867a89, 079a873, f70d7aa, 770703b, ff9932b, e4e8d5e, 615fbbb
New file — a Python example for `StreamingTest` (hunk `@@ -0,0 +1,57 @@`):

```python
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Create a DStream that contains several RDDs to show the StreamingTest of PySpark.
"""
import time
import tempfile
from shutil import rmtree

from pyspark import SparkContext
from pyspark.streaming import StreamingContext
from pyspark.mllib.stat.test import BinarySample, StreamingTest


if __name__ == "__main__":

```

**Contributor:** nit: don't include newline here

```python
    sc = SparkContext(appName="PythonStreamingTestExample")
    ssc = StreamingContext(sc, 1)

    checkpoint_path = tempfile.mkdtemp()
```

**Contributor:** Is this necessary?

```python
    ssc.checkpoint(checkpoint_path)

    # Create the queue through which RDDs can be pushed to a QueueInputDStream.
    rdd_queue = []
```

**Contributor:** use …

```python
    for i in range(5):
        rdd_queue += [ssc.sparkContext.parallelize(
            [BinarySample(True, j) for j in range(1, 1001)], 10)]

    # Create the QueueInputDStream and use it to do some processing.
    input_stream = ssc.queueStream(rdd_queue)

    model = StreamingTest()
    test_result = model.registerStream(input_stream)

    test_result.pprint()

    ssc.start()
    time.sleep(12)
```

**Contributor:** Is this necessary? Doesn't seem to be required for …

```python
    ssc.stop(stopSparkContext=True, stopGraceFully=True)
    try:
        rmtree(checkpoint_path)
    except OSError:
        pass
```
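Note the example above feeds only experiment-group samples (`isExperiment=True`). An actual two-sample comparison would mix both groups in each batch; below is a minimal, hypothetical sketch of building such a batch (the helper name and the group means are illustrative, not part of the PR):

```python
import random

from pyspark.mllib.stat.test import BinarySample


def make_two_group_batch(sc, n=500):
    # Control group drawn around mean 0.0, experiment group around mean 0.5
    # (values chosen only so the two samples differ).
    control = [BinarySample(False, random.gauss(0.0, 1.0)) for _ in range(n)]
    experiment = [BinarySample(True, random.gauss(0.5, 1.0)) for _ in range(n)]
    return sc.parallelize(control + experiment, 10)
```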
Changes to the Scala serialization layer (`SerDeBase`):

```diff
@@ -44,7 +44,7 @@ import org.apache.spark.mllib.regression._
 import org.apache.spark.mllib.stat.{KernelDensity, MultivariateStatisticalSummary, Statistics}
 import org.apache.spark.mllib.stat.correlation.CorrelationNames
 import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
-import org.apache.spark.mllib.stat.test.{ChiSqTestResult, KolmogorovSmirnovTestResult}
+import org.apache.spark.mllib.stat.test._
```

**Contributor:** Why are we using a wildcard import here?

```diff
 import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest}
 import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy}
 import org.apache.spark.mllib.tree.impurity._

@@ -55,6 +55,7 @@ import org.apache.spark.mllib.util.{LinearDataGenerator, MLUtils}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row, SparkSession}
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.api.java.JavaDStream
 import org.apache.spark.util.Utils

 /**

@@ -1258,7 +1259,11 @@ private[spark] abstract class SerDeBase {
   extends IObjectPickler with IObjectConstructor {

   private val cls = implicitly[ClassTag[T]].runtimeClass
-  private val module = PYSPARK_PACKAGE + "." + cls.getName.split('.')(4)
+  // drop 4 to remove "org.apache.spark.mllib", while dropRight 1 to remove class simple name.
+  private val interPath = cls.getName.split('.').drop(4).dropRight(1).mkString(".")
+  private val module = PYSPARK_PACKAGE + "." + interPath

   private val name = cls.getSimpleName

   // register this to Pickler and Unpickler
```
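The module-path change above maps a JVM class name to the corresponding PySpark module. A minimal sketch of the same logic in Python, assuming `PYSPARK_PACKAGE` is `"pyspark.mllib"` (the helper function is hypothetical, for illustration only):

```python
PYSPARK_PACKAGE = "pyspark.mllib"  # assumed value of the Scala constant


def jvm_class_to_pyspark_module(class_name):
    """Mirror of the Scala logic: drop 'org.apache.spark.mllib' and the class name."""
    parts = class_name.split(".")
    inter_path = ".".join(parts[4:-1])  # drop(4) ... dropRight(1)
    return PYSPARK_PACKAGE + "." + inter_path


# The old code kept only split('.')(4), which maps classes in nested packages
# such as stat.test to "pyspark.mllib.stat" instead of "pyspark.mllib.stat.test":
assert jvm_class_to_pyspark_module(
    "org.apache.spark.mllib.stat.test.BinarySample") == "pyspark.mllib.stat.test"
```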
```diff
@@ -1358,6 +1363,41 @@ private[spark] abstract class SerDeBase {
     }
   }.toJavaRDD()
 }

+/**
+ * Convert a DStream of Java objects to a DStream of serialized Python objects, that is usable
+ * by PySpark.
+ */
+def javaToPython(jDStream: JavaDStream[Any]): JavaDStream[Array[Byte]] = {
+  val dStream = jDStream.dstream.mapPartitions { iter =>
+    initialize()  // let it be called in the executor
+    new SerDeUtil.AutoBatchedPickler(iter)
```

**Contributor:** This should be an …

```diff
+  }
+  new JavaDStream[Array[Byte]](dStream)
+}
+
+/**
+ * Convert a DStream of serialized Python objects to a DStream of objects, that is usable by
+ * PySpark.
+ */
+def pythonToJava(pyDStream: JavaDStream[Array[Byte]], batched: Boolean): JavaDStream[Any] = {
+  val dStream = pyDStream.dstream.mapPartitions { iter =>
+    initialize()  // let it be called in the executor
+    val unpickle = new Unpickler
+    iter.flatMap { row =>
+      val obj = unpickle.loads(row)
+      if (batched) {
+        obj match {
+          case list: JArrayList[_] => list.asScala
+          case arr: Array[_] => arr
+        }
+      } else {
+        Seq(obj)
+      }
+    }
+  }
+  new JavaDStream[Any](dStream)
```

**Contributor:** No need for this constructor, just return …

```diff
+}
 }

 /**
```
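Why `pythonToJava` must handle the `batched` case: PySpark's serializers usually pickle lists of records, so a single serialized payload can decode to many elements. A small self-contained Python sketch of that contract (stand-in tuples instead of real `BinarySample`s):

```python
import pickle

# One serialized payload that holds a whole batch of records.
batch = [(True, 1.0), (True, 2.0), (False, 3.0)]
payload = pickle.dumps(batch)

obj = pickle.loads(payload)
# Mirrors the Scala branch: flatten list payloads, wrap single objects.
records = list(obj) if isinstance(obj, list) else [obj]
assert len(records) == 3
```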
```diff
@@ -1576,17 +1616,49 @@ private[spark] object SerDe extends SerDeBase with Serializable {
   }
 }

+private[python] class BinarySamplePickler extends BasePickler[BinarySample] {
+  def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+    val binarySample = obj.asInstanceOf[BinarySample]
+    saveObjects(out, pickler, binarySample.isExperiment, binarySample.value)
+  }
+
+  def construct(args: Array[AnyRef]): AnyRef = {
+    if (args.length != 2) {
+      throw new PickleException("should be 2")
+    }
+    BinarySample(args(0).asInstanceOf[Boolean], args(1).asInstanceOf[Double])
+  }
+}
+
+private[python] class StreamingTestResultPickler extends BasePickler[StreamingTestResult] {
```

**Contributor:** Do we need to test these in …

```diff
+  def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+    val result = obj.asInstanceOf[StreamingTestResult]
+    saveObjects(out, pickler, result.pValue, result.degreesOfFreedom, result.statistic,
+      result.method, result.nullHypothesis)
+  }
+
+  def construct(args: Array[AnyRef]): AnyRef = {
+    if (args.length != 5) {
+      throw new PickleException("should be 5")
+    }
+    new StreamingTestResult(args(0).asInstanceOf[Double], args(1).asInstanceOf[Double],
+      args(2).asInstanceOf[Double], args(3).asInstanceOf[String], args(4).asInstanceOf[String])
+  }
+}
+
 var initialized = false
 // This should be called before trying to serialize any above classes
 // In cluster mode, this should be put in the closure
 override def initialize(): Unit = {
   SerDeUtil.initialize()
   synchronized {
     if (!initialized) {
+      new BinarySamplePickler().register()
       new DenseVectorPickler().register()
       new DenseMatrixPickler().register()
       new SparseMatrixPickler().register()
       new SparseVectorPickler().register()
+      new StreamingTestResultPickler().register()
       new LabeledPointPickler().register()
       new RatingPickler().register()
       initialized = true
```
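These Scala picklers must stay in sync with the `__reduce__` methods defined on the Python side (see the `pyspark.mllib.stat.test` changes below): unpickling calls the Python class with the saved positional arguments. A small self-contained check of that round-trip contract:

```python
import pickle
from collections import namedtuple


class BinarySample(namedtuple("BinarySample", ["isExperiment", "value"])):
    def __reduce__(self):
        # Two positional args, matching the 2-arg construct() in BinarySamplePickler.
        return BinarySample, (bool(self.isExperiment), float(self.value))


sample = BinarySample(True, 1.5)
assert pickle.loads(pickle.dumps(sample)) == sample
```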
Change to the visibility of `StreamingTestResult` in the Scala test-result classes:

```diff
@@ -114,7 +114,7 @@ class KolmogorovSmirnovTestResult private[stat] (
  * Object containing the test results for streaming testing.
  */
 @Since("1.6.0")
-private[stat] class StreamingTestResult @Since("1.6.0") (
+class StreamingTestResult @Since("1.6.0") (
```

**Contributor:** Does this need to be public? The Java API doesn't seem to require it.

```diff
     @Since("1.6.0") override val pValue: Double,
     @Since("1.6.0") override val degreesOfFreedom: Double,
     @Since("1.6.0") override val statistic: Double,
```
Changes to `pyspark/mllib/stat/test.py`:

```diff
@@ -15,10 +15,15 @@
 # limitations under the License.
 #

+from collections import namedtuple
+
+from pyspark import SparkContext, since
 from pyspark.mllib.common import inherit_doc, JavaModelWrapper
+from pyspark.streaming.dstream import DStream


-__all__ = ["ChiSqTestResult", "KolmogorovSmirnovTestResult"]
+__all__ = ["ChiSqTestResult", "KolmogorovSmirnovTestResult", "BinarySample", "StreamingTest",
+           "StreamingTestResult"]
```

**Contributor:** nit: alphabetize

```diff


 class TestResult(JavaModelWrapper):
```
The rest of the additions to the same file (hunk `@@ -80,3 +85,118 @@`, shown with surrounding context):

```python
class KolmogorovSmirnovTestResult(TestResult):
    """
    Contains test results for the Kolmogorov-Smirnov test.
    """


class BinarySample(namedtuple("BinarySample", ["isExperiment", "value"])):
    """
    Represents a (isExperiment, value) tuple.

    .. versionadded:: 2.0.0
    """

    def __reduce__(self):
        return BinarySample, (bool(self.isExperiment), float(self.value))


class StreamingTestResult(namedtuple("StreamingTestResult",
                                     ["pValue", "degreesOfFreedom", "statistic", "method",
                                      "nullHypothesis"])):
    """
    Contains test results for StreamingTest.

    .. versionadded:: 2.0.0
    """

    def __reduce__(self):
        return StreamingTestResult, (float(self.pValue),
                                     float(self.degreesOfFreedom), float(self.statistic),
                                     str(self.method), str(self.nullHypothesis))


class StreamingTest(object):
    """
    .. note:: Experimental

    Online 2-sample significance testing for a stream of (Boolean, Double) pairs. The Boolean
    identifies which sample each observation comes from, and the Double is the numeric value of
    the observation.

    To address novelty effects, the `peacePeriod` specifies a set number of initial RDD batches
    of the DStream to be dropped from significance testing.

    The `windowSize` sets the number of batches each significance test is to be performed over.
    The window is sliding with a stride length of 1 batch. Setting windowSize to 0 will perform
    cumulative processing, using all batches seen so far.

    Different tests may be used for assessing statistical significance depending on assumptions
    satisfied by the data. For more details, see StreamingTestMethod. The `testMethod` specifies
    which test will be used.

    .. versionadded:: 2.0.0
    """

    def __init__(self):
        self._peacePeriod = 0
        self._windowSize = 0
        self._testMethod = "welch"

    @since('2.0.0')
    def setPeacePeriod(self, peacePeriod):
        """
        Update peacePeriod.

        :param peacePeriod:
          Number of initial RDD batches of the DStream to be dropped from significance testing.
        """
        self._peacePeriod = peacePeriod

    @since('2.0.0')
    def setWindowSize(self, windowSize):
        """
        Update windowSize.

        :param windowSize:
          Number of batches each significance test is to be performed over.
        """
        self._windowSize = windowSize

    @since('2.0.0')
    def setTestMethod(self, testMethod):
        """
        Update the test method.

        :param testMethod:
          Currently supported tests: `welch`, `student`.
        """
        # Note: the message must be a separate assert argument; asserting a
        # (condition, message) tuple would always evaluate to True.
        assert testMethod in ("welch", "student"), \
            "Currently supported tests: \"welch\", \"student\""
        self._testMethod = testMethod

    @since('2.0.0')
    def registerStream(self, data):
```

**Contributor:** nit: data -> dstream

```python
        """
        Register a data stream to get its test result.

        :param data:
          The input data stream, where each element is a BinarySample instance.
        """
        self._validate(data)
        sc = SparkContext._active_spark_context

        streamingTest = sc._jvm.org.apache.spark.mllib.stat.test.StreamingTest()
```

**Contributor:** Why did we not define a …

```python
        streamingTest.setPeacePeriod(self._peacePeriod)
        streamingTest.setWindowSize(self._windowSize)
        streamingTest.setTestMethod(self._testMethod)

        javaDStream = sc._jvm.SerDe.pythonToJava(data._jdstream, True)
        testResult = streamingTest.registerStream(javaDStream)
```

**Contributor:** Why do we need … (see spark/python/pyspark/mllib/clustering.py, line 773 at 39e2bad)

```python
        pythonTestResult = sc._jvm.SerDe.javaToPython(testResult)

        pyResult = DStream(pythonTestResult, data._ssc, data._jrdd_deserializer)

        return pyResult

    @classmethod
    def _validate(cls, samples):
        if not isinstance(samples, DStream):
            raise TypeError("BinarySample should be represented by a DStream, "
                            "but got %s." % type(samples))
```
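Taken together, the intended call pattern for the new API might look like the following sketch (the wrapper function is hypothetical and the parameter values are illustrative only):

```python
from pyspark.mllib.stat.test import StreamingTest


def attach_streaming_test(samples):
    """Attach a StreamingTest to a DStream of BinarySample and return the result DStream."""
    test = StreamingTest()
    test.setPeacePeriod(1)       # drop the first batch (novelty effects)
    test.setWindowSize(3)        # sliding window of 3 batches; 0 means cumulative
    test.setTestMethod("welch")  # or "student"
    return test.registerStream(samples)  # DStream of StreamingTestResult
```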
Changes to the imports in `pyspark/mllib/tests.py`:

```diff
@@ -59,14 +59,15 @@
 from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD
 from pyspark.mllib.random import RandomRDDs
 from pyspark.mllib.stat import Statistics
+from pyspark.mllib.stat.test import BinarySample, StreamingTest, StreamingTestResult
 from pyspark.mllib.feature import HashingTF
 from pyspark.mllib.feature import Word2Vec
 from pyspark.mllib.feature import IDF
 from pyspark.mllib.feature import StandardScaler, ElementwiseProduct
 from pyspark.mllib.util import LinearDataGenerator
 from pyspark.mllib.util import MLUtils
 from pyspark.serializers import PickleSerializer
-from pyspark.streaming import StreamingContext
+from pyspark.streaming.tests import PySparkStreamingTestCase
 from pyspark.sql import SparkSession
 from pyspark.sql.utils import IllegalArgumentException
 from pyspark.streaming import StreamingContext
```
And the new test case, added after `test_binary_term_freqs` (hunk `@@ -1688,6 +1689,44 @@`):

```python
class StreamingTestTest(PySparkStreamingTestCase):
```

**Contributor:** Why not …

```python
    def test_streaming_test_result_and_model(self):
        """
        Assert that StreamingTest returns a valid result, and check its setter methods.
        """
        checkpoint_path = tempfile.mkdtemp()
        self.ssc.checkpoint(checkpoint_path)
```

**Contributor:** Is this necessary?

```python
        # Create the queue through which RDDs can be pushed to a QueueInputDStream.
        rdd_queue = []
        for i in range(5):
            rdd_queue += [self.ssc.sparkContext.parallelize(
                [BinarySample(True, j) for j in range(1, 1001)], 10)]

        # Create the QueueInputDStream and use it to do some processing.
        input_stream = self.ssc.queueStream(rdd_queue)

        model = StreamingTest()
        model.setPeacePeriod(1)
```

**Contributor:** Can we break this into another test just for model params, like spark/python/pyspark/mllib/tests.py, line 1165 at 39e2bad?

```python
        model.setWindowSize(2)
        model.setTestMethod("student")

        test_result = model.registerStream(input_stream)
        res = self._take(test_result, 1)[0]
        self.assertTrue(isinstance(res, StreamingTestResult))
        self.assertEqual(res.method, "Student's 2-sample t-test")

        self.assertEqual(model._peacePeriod, 1)
        self.assertEqual(model._windowSize, 2)
        self.assertEqual(model._testMethod, "student")

        try:
            rmtree(checkpoint_path)
        except OSError:
            pass


if __name__ == "__main__":
    from pyspark.mllib.tests import *
    if not _have_scipy:
```
**Contributor:** Seems like other examples are including a `from __future__ import print_function` here.