From c7f39dad12327f1e224930377df14a40dbb15c1d Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Sat, 10 Jan 2015 14:36:27 -0500 Subject: [PATCH 1/5] Add broken implementation of AB testing. Fix AB testing implementation and add unit tests. Extract t-testing code out of OnlineABTesting. Add peace period for dropping first k entries of each A/B group. Add numDim to MultivariateOnlineSummarizer. Refactored ABTestingMethod into sealed trait. Add (non-sliding) testing window functionality. Fix peace period implementation. Fix test window batching. Handle (inelegantly) closure capture for ABTestMethod Improve handling of OnlineABTestMethod closure by moving DStream processing method into Serializable class. Fixed flaky peacePeriod test. Add ScalaDocs and format to style guide. Add OnlineABTestExample. Format code to style guide. Switch MultivariateOnlineSummarizer to univariate StatsCounter. Reduce number of passes in pairSummaries. Add test for behavior when missing data from one group. Remove numDim from MultivariateOnlineSummarizer. Style guide in OnlineABTestSuite Fix broken tests Style fix Fix runStream expectedOutput --- .../examples/mllib/OnlineABTestExample.scala | 90 +++++++ .../spark/mllib/stat/test/OnlineABTest.scala | 136 ++++++++++ .../mllib/stat/test/OnlineABTestMethod.scala | 168 ++++++++++++ .../spark/mllib/stat/test/TestResult.scala | 22 ++ .../spark/mllib/stat/OnlineABTestSuite.scala | 244 ++++++++++++++++++ 5 files changed, 660 insertions(+) create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/OnlineABTestExample.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/stat/test/OnlineABTest.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/stat/test/OnlineABTestMethod.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/stat/OnlineABTestSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/OnlineABTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/OnlineABTestExample.scala new file mode 100644 index 000000000000..f38892b06f5b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/OnlineABTestExample.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.mllib.stat.test.OnlineABTest +import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.util.Utils + +/** + * Perform online A/B testing using Welch's 2-sample t-test on a stream of data, where the data + * stream arrives as text files in a directory. Stops when the two groups are statistically + * significant (p-value < 0.05) or after a user-specified timeout in number of batches is exceeded. + * + * The rows of the text files must be in the form `Boolean, Double`. For example: + * false, -3.92 + * true, 99.32 + * + * Usage: + * OnlineABTestExample + * + * To run on your local machine using the directory `dataDir` with 5 seconds between each batch and + * a timeout after 100 insignificant batches, call: + * $ bin/run-example mllib.OnlineABTestExample dataDir 5 100 + * + * As you add text files to `dataDir` the significance test wil continually update every + * `batchDuration` seconds until the test becomes significant (p-value < 0.05) or the number of + * batches processed exceeds `numBatchesTimeout`. + */ +object OnlineABTestExample { + + def main(args: Array[String]) { + if (args.length != 3) { + // scalastyle:off println + System.err.println( + "Usage: OnlineABTestExample " + + " ") + // scalastyle:on println + System.exit(1) + } + val dataDir = args(0) + val batchDuration = Seconds(args(1).toLong) + val numBatchesTimeout = args(2).toInt + + val conf = new SparkConf().setMaster("local").setAppName("OnlineABTestExample") + val ssc = new StreamingContext(conf, batchDuration) + ssc.checkpoint({ + val dir = Utils.createTempDir() + dir.toString + }) + + val data = ssc.textFileStream(dataDir).map(line => line.split(",") match { + case Array(label, value) => (label.toBoolean, value.toDouble) + }) + + val ABTest = new OnlineABTest() + .setPeacePeriod(0) + .setWindowSize(0) + .setTestMethod("welch") + + val out = ABTest.registerStream(data) + out.print() + + // Stop processing if test becomes significant or we time out + var timeoutCounter = numBatchesTimeout + out.foreachRDD { rdd => + timeoutCounter -= 1 + val anySignificant = rdd.map(_.pValue < 0.05).fold(false)(_ || _) + if (timeoutCounter == 0 || anySignificant) rdd.context.stop() + } + + ssc.start() + ssc.awaitTermination() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/OnlineABTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/OnlineABTest.scala new file mode 100644 index 000000000000..70a56bac16b1 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/OnlineABTest.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.test + +import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.StatCounter + +/** + * :: DeveloperApi :: + * Performs online significance testing for a stream of A/B testing results. + * + * To address novelty affects, the peacePeriod specifies a set number of initial + * [[org.apache.spark.rdd.RDD]] batches of the [[DStream]] to be dropped from significance testing. + * + * The windowSize sets the number of batches each significance test is to be performed over. The + * window is sliding with a stride length of 1 batch. Setting windowSize to 0 will perform + * cumulative processing, using all batches seen so far. + * + * Different tests may be used for assessing statistical significance depending on assumptions + * satisfied by data. For more details, see [[OnlineABTestMethod]]. The testMethod specifies which + * test will be used. + * + * Use a builder pattern to construct an online A/B test in an application, like: + * + * val model = new OnlineABTest() + * .setPeacePeriod(10) + * .setWindowSize(0) + * .setTestMethod("welch") + * .registerStream(DStream) + */ +@DeveloperApi +class OnlineABTest( + var peacePeriod: Int = 0, + var windowSize: Int = 0, + var testMethod: OnlineABTestMethod = WelchTTest) extends Logging with Serializable { + + /** Set the number of initial batches to ignore. */ + def setPeacePeriod(peacePeriod: Int): this.type = { + this.peacePeriod = peacePeriod + this + } + + /** + * Set the number of batches to compute significance tests over. + * A value of 0 will use all batches seen so far. + */ + def setWindowSize(windowSize: Int): this.type = { + this.windowSize = windowSize + this + } + + /** Set the statistical method used for significance testing. */ + def setTestMethod(method: String): this.type = { + this.testMethod = OnlineABTestMethodNames.getTestMethodFromName(method) + this + } + + /** + * Register a [[DStream]] of values for significance testing. + * + * @param data stream of (key,value) pairs where the key is the group membership (control or + * treatment) and the value is the numerical metric to test for significance + * @return stream of significance testing results + */ + def registerStream(data: DStream[(Boolean, Double)]): DStream[OnlineABTestResult] = { + val dataAfterPeacePeriod = dropPeacePeriod(data) + val summarizedData = summarizeByKeyAndWindow(dataAfterPeacePeriod) + val pairedSummaries = pairSummaries(summarizedData) + val testResults = testMethod.doTest(pairedSummaries) + + testResults + } + + /** Drop all batches inside the peace period. */ + private[stat] def dropPeacePeriod( + data: DStream[(Boolean, Double)]): DStream[(Boolean, Double)] = { + data.transform { (rdd, time) => + if (time.milliseconds > data.slideDuration.milliseconds * peacePeriod) { + rdd + } else { + rdd.filter(_ => false) // TODO: Is there a better way to drop a RDD from a DStream? + } + } + } + + /** Compute summary statistics over each key and the specified test window size. */ + private[stat] def summarizeByKeyAndWindow( + data: DStream[(Boolean, Double)]): DStream[(Boolean, StatCounter)] = { + if (this.windowSize == 0) { + data.updateStateByKey[StatCounter]( + (newValues: Seq[Double], oldSummary: Option[StatCounter]) => { + val newSummary = oldSummary.getOrElse(new StatCounter()) + newSummary.merge(newValues) + Some(newSummary) + }) + } else { + val windowDuration = data.slideDuration * this.windowSize + data + .groupByKeyAndWindow(windowDuration) + .mapValues { values => + val summary = new StatCounter() + values.foreach(value => summary.merge(value)) + summary + } + } + } + + /** + * Transform a stream of summaries into pairs representing summary statistics for group A and + * group B up to this batch. + */ + private[stat] def pairSummaries(summarizedData: DStream[(Boolean, StatCounter)]) + : DStream[(StatCounter, StatCounter)] = { + summarizedData + .map[(Int, StatCounter)](x => (0, x._2)) + .groupByKey() // Iterable[StatCounter] should be length two, one for each A/B group + .map(x => (x._2.head, x._2.last) ) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/OnlineABTestMethod.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/OnlineABTestMethod.scala new file mode 100644 index 000000000000..4ab3436f48ce --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/OnlineABTestMethod.scala @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.test + +import java.io.Serializable + +import scala.language.implicitConversions +import scala.math.pow + +import com.twitter.chill.MeatLocker +import org.apache.commons.math3.stat.descriptive.StatisticalSummaryValues +import org.apache.commons.math3.stat.inference.TTest + +import org.apache.spark.Logging +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.StatCounter + +/** + * Significance testing methods for [[OnlineABTest]]. New statistical tests for assessing + * significance of AB testing results should be implemented in OnlineABTestMethod.scala, extend + * [[OnlineABTestMethod]], and introduce a new entry in [[OnlineABTestMethodNames.NameToObjectMap]]. + */ +sealed trait OnlineABTestMethod extends Serializable { + + val MethodName: String + val NullHypothesis: String + + protected type SummaryPairStream = + DStream[(StatCounter, StatCounter)] + + /** + * Perform online 2-sample statistical significance testing. + * + * @param sampleSummaries stream pairs of summary statistics for the 2 samples + * @return stream of rest results + */ + def doTest(sampleSummaries: SummaryPairStream): DStream[OnlineABTestResult] + + + /** + * Implicit adapter to convert between online summary statistics type and the type required by + * the t-testing libraries. + */ + protected implicit def toApacheCommonsStats( + summaryStats: StatCounter): StatisticalSummaryValues = { + new StatisticalSummaryValues( + summaryStats.mean, + summaryStats.variance, + summaryStats.count, + summaryStats.max, + summaryStats.min, + summaryStats.mean * summaryStats.count + ) + } +} + +/** + * Performs Welch's 2-sample t-test. The null hypothesis is that the two data sets have equal mean. + * This test does not assume equal variance between the two samples and does not assume equal + * sample size. + * + * More information: http://en.wikipedia.org/wiki/Welch%27s_t_test + */ +private[stat] object WelchTTest extends OnlineABTestMethod with Logging { + + final val MethodName = "Welch's 2-sample T-test" + final val NullHypothesis = "A and B groups have same mean" + + private final val TTester = MeatLocker(new TTest()) + + def doTest(data: SummaryPairStream): DStream[OnlineABTestResult] = + data.map[OnlineABTestResult]((test _).tupled) + + private def test( + statsA: StatCounter, + statsB: StatCounter): OnlineABTestResult = { + def welchDF(sample1: StatisticalSummaryValues, sample2: StatisticalSummaryValues): Double = { + val s1 = sample1.getVariance + val n1 = sample1.getN + val s2 = sample2.getVariance + val n2 = sample2.getN + + val a = pow(s1, 2) / n1 + val b = pow(s2, 2) / n2 + + pow(a + b, 2) / ((pow(a, 2) / (n1 - 1)) + (pow(b, 2) / (n2 - 1))) + } + + new OnlineABTestResult( + TTester.get.tTest(statsA, statsB), + welchDF(statsA, statsB), + TTester.get.t(statsA, statsB), + MethodName, + NullHypothesis + ) + } +} + +/** + * Performs Students's 2-sample t-test. The null hypothesis is that the two data sets have equal + * mean. This test assumes equal variance between the two samples and does not assume equal sample + * size. For unequal variances, Welch's t-test should be used instead. + * + * More information: http://en.wikipedia.org/wiki/Student%27s_t-test + */ +private[stat] object StudentTTest extends OnlineABTestMethod with Logging { + + final val MethodName = "Student's 2-sample T-test" + final val NullHypothesis = "A and B groups have same mean" + + private final val TTester = MeatLocker(new TTest()) + + def doTest(data: SummaryPairStream): DStream[OnlineABTestResult] = + data.map[OnlineABTestResult]((test _).tupled) + + private def test( + statsA: StatCounter, + statsB: StatCounter): OnlineABTestResult = { + def studentDF(sample1: StatisticalSummaryValues, sample2: StatisticalSummaryValues): Double = + sample1.getN + sample2.getN - 2 + + new OnlineABTestResult( + TTester.get.homoscedasticTTest(statsA, statsB), + studentDF(statsA, statsB), + TTester.get.homoscedasticT(statsA, statsB), + MethodName, + NullHypothesis + ) + } +} + +/** + * Maintains supported [[OnlineABTestMethod]] names and handles conversion between strings used in + * [[OnlineABTest]] configuration and actual method implementation. + * + * Currently supported correlations: `welch`, `student`. + */ +private[stat] object OnlineABTestMethodNames { + // Note: after new OnlineABTestMethods are implemented, please update this map. + final val NameToObjectMap = Map(("welch", WelchTTest), ("student", StudentTTest)) + + // Match input correlation name with a known name via simple string matching. + def getTestMethodFromName(method: String): OnlineABTestMethod = { + try { + NameToObjectMap(method) + } catch { + case nse: NoSuchElementException => + throw new IllegalArgumentException("Unrecognized method name. Supported A/B test methods: " + + NameToObjectMap.keys.mkString(", ")) + } + } +} + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala index d01b3707be94..eb9a0cf482c5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala @@ -115,3 +115,25 @@ class KolmogorovSmirnovTestResult private[stat] ( "Kolmogorov-Smirnov test summary:\n" + super.toString } } + +/** + * :: Experimental :: + * Object containing the test results for online A/B testing. + */ +@Experimental +@Since("1.6.0") +private[stat] class OnlineABTestResult( + @Since("1.6.0") override val pValue: Double, + @Since("1.6.0") override val degreesOfFreedom: Double, + @Since("1.6.0") override val statistic: Double, + @Since("1.6.0") val method: String, + @Since("1.6.0") override val nullHypothesis: String) + extends TestResult[Double] with Serializable { + + override def toString: String = { + "A/B test summary:\n" + + s"method: $method\n" + + super.toString + } +} + diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/OnlineABTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/OnlineABTestSuite.scala new file mode 100644 index 000000000000..fc5438fdbc79 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/OnlineABTestSuite.scala @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat + +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.StatCounter + +import org.apache.spark.mllib.stat.test.{OnlineABTest, OnlineABTestResult, StudentTTest, WelchTTest} +import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.random.XORShiftRandom + +class OnlineABTestSuite extends SparkFunSuite with TestSuiteBase { + + override def maxWaitTimeMillis : Int = 30000 + + test("accuracy for null hypothesis using welch t-test") { + // set parameters + val testMethod = "welch" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = 0 + val stdevA = 0.001 + val meanB = 0 + val stdevB = 0.001 + + val model = new OnlineABTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[OnlineABTestResult](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.forall(res => + res.pValue > 0.05 && res.method == WelchTTest.MethodName)) + } + + test("accuracy for alternative hypothesis using welch t-test") { + // set parameters + val testMethod = "welch" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new OnlineABTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[OnlineABTestResult](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.forall(res => + res.pValue < 0.05 && res.method == WelchTTest.MethodName)) + } + + test("accuracy for null hypothesis using student t-test") { + // set parameters + val testMethod = "student" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = 0 + val stdevA = 0.001 + val meanB = 0 + val stdevB = 0.001 + + val model = new OnlineABTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[OnlineABTestResult](ssc, numBatches, numBatches) + + + assert(outputBatches.flatten.forall(res => + res.pValue > 0.05 && res.method == StudentTTest.MethodName)) + } + + test("accuracy for alternative hypothesis using student t-test") { + // set parameters + val testMethod = "student" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new OnlineABTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[OnlineABTestResult](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.forall(res => + res.pValue < 0.05 && res.method == StudentTTest.MethodName)) + } + + test("batches within same test window are grouped") { + // set parameters + val testWindow = 3 + val numBatches = 5 + val pointsPerBatch = 100 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new OnlineABTest() + .setWindowSize(testWindow) + .setPeacePeriod(0) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, + (inputDStream: DStream[(Boolean, Double)]) => model.summarizeByKeyAndWindow(inputDStream)) + val outputBatches = runStreams[(Boolean, StatCounter)](ssc, numBatches, 3) + val outputCounts = outputBatches.flatten.map(_._2.count) + + // number of batches seen so far does not exceed testWindow, expect counts to continue growing + for (i <- 0 until testWindow) { + assert(outputCounts.drop(2 * i).take(2).forall(_ == (i + 1) * pointsPerBatch / 2)) + } + + // number of batches seen exceeds testWindow, expect counts to be constant + assert(outputCounts.drop(2 * (testWindow - 1)).forall(_ == testWindow * pointsPerBatch / 2)) + } + + + test("entries in peace period are dropped") { + // set parameters + val peacePeriod = 3 + val numBatches = 7 + val pointsPerBatch = 1000 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new OnlineABTest() + .setWindowSize(0) + .setPeacePeriod(peacePeriod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.dropPeacePeriod(inputDStream)) + val outputBatches = runStreams[(Boolean, Double)](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.length == (numBatches - peacePeriod) * pointsPerBatch) + } + + test("null hypothesis when only data from one group is present") { + // set parameters + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = 0 + val stdevA = 0.001 + val meanB = 0 + val stdevB = 0.001 + + val model = new OnlineABTest() + .setWindowSize(0) + .setPeacePeriod(0) + + val input = generateTestData(numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + .map(batch => batch.filter(_._1)) // only keep one test group + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[OnlineABTestResult](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.forall(result => (result.pValue - 1.0).abs < 0.001)) + } + + // Generate testing input with half of the entries in group A and half in group B + private def generateTestData( + numBatches: Int, + pointsPerBatch: Int, + meanA: Double, + stdevA: Double, + meanB: Double, + stdevB: Double, + seed: Int): (IndexedSeq[IndexedSeq[(Boolean, Double)]]) = { + val rand = new XORShiftRandom(seed) + val numTrues = pointsPerBatch / 2 + val data = (0 until numBatches).map { i => + (0 until numTrues).map { idx => (true, meanA + stdevA * rand.nextGaussian())} ++ + (pointsPerBatch / 2 until pointsPerBatch).map { idx => + (false, meanB + stdevB * rand.nextGaussian()) + } + } + + data + } +} From 249341874112485c26c5e1965a74c60c443d13cd Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Fri, 11 Sep 2015 12:25:57 -0700 Subject: [PATCH 2/5] Fixes tests --- .../scala/org/apache/spark/mllib/stat/OnlineABTestSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/OnlineABTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/OnlineABTestSuite.scala index fc5438fdbc79..8847446f1757 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/OnlineABTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/OnlineABTestSuite.scala @@ -159,7 +159,7 @@ class OnlineABTestSuite extends SparkFunSuite with TestSuiteBase { val ssc = setupStreams( input, (inputDStream: DStream[(Boolean, Double)]) => model.summarizeByKeyAndWindow(inputDStream)) - val outputBatches = runStreams[(Boolean, StatCounter)](ssc, numBatches, 3) + val outputBatches = runStreams[(Boolean, StatCounter)](ssc, numBatches, 4) val outputCounts = outputBatches.flatten.map(_._2.count) // number of batches seen so far does not exceed testWindow, expect counts to continue growing From b81bb53ae79c19e2990a7781436b83d0b53ab1a4 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Fri, 11 Sep 2015 12:52:49 -0700 Subject: [PATCH 3/5] Renames to StreamingTest and improves docs --- ...ample.scala => StreamingTestExample.scala} | 18 ++--- ...OnlineABTest.scala => StreamingTest.scala} | 51 ++++++++------ ...Method.scala => StreamingTestMethod.scala} | 69 +++++++++---------- .../spark/mllib/stat/test/TestResult.scala | 6 +- ...stSuite.scala => StreamingTestSuite.scala} | 31 ++++----- 5 files changed, 90 insertions(+), 85 deletions(-) rename examples/src/main/scala/org/apache/spark/examples/mllib/{OnlineABTestExample.scala => StreamingTestExample.scala} (84%) rename mllib/src/main/scala/org/apache/spark/mllib/stat/test/{OnlineABTest.scala => StreamingTest.scala} (76%) rename mllib/src/main/scala/org/apache/spark/mllib/stat/test/{OnlineABTestMethod.scala => StreamingTestMethod.scala} (63%) rename mllib/src/test/scala/org/apache/spark/mllib/stat/{OnlineABTestSuite.scala => StreamingTestSuite.scala} (89%) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/OnlineABTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala similarity index 84% rename from examples/src/main/scala/org/apache/spark/examples/mllib/OnlineABTestExample.scala rename to examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala index f38892b06f5b..ab29f90254d3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/OnlineABTestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala @@ -18,12 +18,12 @@ package org.apache.spark.examples.mllib import org.apache.spark.SparkConf -import org.apache.spark.mllib.stat.test.OnlineABTest +import org.apache.spark.mllib.stat.test.StreamingTest import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.util.Utils /** - * Perform online A/B testing using Welch's 2-sample t-test on a stream of data, where the data + * Perform streaming testing using Welch's 2-sample t-test on a stream of data, where the data * stream arrives as text files in a directory. Stops when the two groups are statistically * significant (p-value < 0.05) or after a user-specified timeout in number of batches is exceeded. * @@ -32,23 +32,23 @@ import org.apache.spark.util.Utils * true, 99.32 * * Usage: - * OnlineABTestExample + * StreamingTestExample * * To run on your local machine using the directory `dataDir` with 5 seconds between each batch and * a timeout after 100 insignificant batches, call: - * $ bin/run-example mllib.OnlineABTestExample dataDir 5 100 + * $ bin/run-example mllib.StreamingTestExample dataDir 5 100 * * As you add text files to `dataDir` the significance test wil continually update every * `batchDuration` seconds until the test becomes significant (p-value < 0.05) or the number of * batches processed exceeds `numBatchesTimeout`. */ -object OnlineABTestExample { +object StreamingTestExample { def main(args: Array[String]) { if (args.length != 3) { // scalastyle:off println System.err.println( - "Usage: OnlineABTestExample " + + "Usage: StreamingTestExample " + " ") // scalastyle:on println System.exit(1) @@ -57,7 +57,7 @@ object OnlineABTestExample { val batchDuration = Seconds(args(1).toLong) val numBatchesTimeout = args(2).toInt - val conf = new SparkConf().setMaster("local").setAppName("OnlineABTestExample") + val conf = new SparkConf().setMaster("local").setAppName("StreamingTestExample") val ssc = new StreamingContext(conf, batchDuration) ssc.checkpoint({ val dir = Utils.createTempDir() @@ -68,12 +68,12 @@ object OnlineABTestExample { case Array(label, value) => (label.toBoolean, value.toDouble) }) - val ABTest = new OnlineABTest() + val streamingTest = new StreamingTest() .setPeacePeriod(0) .setWindowSize(0) .setTestMethod("welch") - val out = ABTest.registerStream(data) + val out = streamingTest.registerStream(data) out.print() // Stop processing if test becomes significant or we time out diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/OnlineABTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala similarity index 76% rename from mllib/src/main/scala/org/apache/spark/mllib/stat/test/OnlineABTest.scala rename to mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala index 70a56bac16b1..86a1a1753417 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/OnlineABTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala @@ -18,40 +18,46 @@ package org.apache.spark.mllib.stat.test import org.apache.spark.Logging -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.StatCounter /** - * :: DeveloperApi :: - * Performs online significance testing for a stream of A/B testing results. + * :: Experimental :: + * Performs online 2-sample significance testing for a stream of (Boolean, Double) pairs. The + * Boolean identifies which sample each observation comes from, and the Double is the numeric value + * of the observation. * - * To address novelty affects, the peacePeriod specifies a set number of initial + * To address novelty affects, the `peacePeriod` specifies a set number of initial * [[org.apache.spark.rdd.RDD]] batches of the [[DStream]] to be dropped from significance testing. * - * The windowSize sets the number of batches each significance test is to be performed over. The + * The `windowSize` sets the number of batches each significance test is to be performed over. The * window is sliding with a stride length of 1 batch. Setting windowSize to 0 will perform * cumulative processing, using all batches seen so far. * * Different tests may be used for assessing statistical significance depending on assumptions - * satisfied by data. For more details, see [[OnlineABTestMethod]]. The testMethod specifies which - * test will be used. + * satisfied by data. For more details, see [[StreamingTestMethod]]. The `testMethod` specifies + * which test will be used. * - * Use a builder pattern to construct an online A/B test in an application, like: - * - * val model = new OnlineABTest() - * .setPeacePeriod(10) - * .setWindowSize(0) - * .setTestMethod("welch") - * .registerStream(DStream) + * Use a builder pattern to construct a streaming test in an application, for example: + * ``` + * val model = new OnlineABTest() + * .setPeacePeriod(10) + * .setWindowSize(0) + * .setTestMethod("welch") + * .registerStream(DStream) + * ``` */ -@DeveloperApi -class OnlineABTest( - var peacePeriod: Int = 0, - var windowSize: Int = 0, - var testMethod: OnlineABTestMethod = WelchTTest) extends Logging with Serializable { +@Experimental +@Since("1.6.0") +class StreamingTest( + @Since("1.6.0") var peacePeriod: Int = 0, + @Since("1.6.0") var windowSize: Int = 0, + @Since("1.6.0") var testMethod: StreamingTestMethod = WelchTTest) + extends Logging with Serializable { /** Set the number of initial batches to ignore. */ + @Since("1.6.0") def setPeacePeriod(peacePeriod: Int): this.type = { this.peacePeriod = peacePeriod this @@ -61,14 +67,16 @@ class OnlineABTest( * Set the number of batches to compute significance tests over. * A value of 0 will use all batches seen so far. */ + @Since("1.6.0") def setWindowSize(windowSize: Int): this.type = { this.windowSize = windowSize this } /** Set the statistical method used for significance testing. */ + @Since("1.6.0") def setTestMethod(method: String): this.type = { - this.testMethod = OnlineABTestMethodNames.getTestMethodFromName(method) + this.testMethod = StreamingTestMethod.getTestMethodFromName(method) this } @@ -79,7 +87,8 @@ class OnlineABTest( * treatment) and the value is the numerical metric to test for significance * @return stream of significance testing results */ - def registerStream(data: DStream[(Boolean, Double)]): DStream[OnlineABTestResult] = { + @Since("1.6.0") + def registerStream(data: DStream[(Boolean, Double)]): DStream[StreamingTestResult] = { val dataAfterPeacePeriod = dropPeacePeriod(data) val summarizedData = summarizeByKeyAndWindow(dataAfterPeacePeriod) val pairedSummaries = pairSummaries(summarizedData) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/OnlineABTestMethod.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala similarity index 63% rename from mllib/src/main/scala/org/apache/spark/mllib/stat/test/OnlineABTestMethod.scala rename to mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala index 4ab3436f48ce..23e0dead0297 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/OnlineABTestMethod.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala @@ -31,11 +31,11 @@ import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.StatCounter /** - * Significance testing methods for [[OnlineABTest]]. New statistical tests for assessing - * significance of AB testing results should be implemented in OnlineABTestMethod.scala, extend - * [[OnlineABTestMethod]], and introduce a new entry in [[OnlineABTestMethodNames.NameToObjectMap]]. + * Significance testing methods for [[StreamingTest]]. New 2-sample statistical significance tests + * should extend [[StreamingTestMethod]] and introduce a new entry in + * [[StreamingTestMethod.TEST_NAME_TO_OBJECT]] */ -sealed trait OnlineABTestMethod extends Serializable { +private[stat] sealed trait StreamingTestMethod extends Serializable { val MethodName: String val NullHypothesis: String @@ -44,16 +44,15 @@ sealed trait OnlineABTestMethod extends Serializable { DStream[(StatCounter, StatCounter)] /** - * Perform online 2-sample statistical significance testing. + * Perform streaming 2-sample statistical significance testing. * * @param sampleSummaries stream pairs of summary statistics for the 2 samples * @return stream of rest results */ - def doTest(sampleSummaries: SummaryPairStream): DStream[OnlineABTestResult] - + def doTest(sampleSummaries: SummaryPairStream): DStream[StreamingTestResult] /** - * Implicit adapter to convert between online summary statistics type and the type required by + * Implicit adapter to convert between streaming summary statistics type and the type required by * the t-testing libraries. */ protected implicit def toApacheCommonsStats( @@ -76,19 +75,19 @@ sealed trait OnlineABTestMethod extends Serializable { * * More information: http://en.wikipedia.org/wiki/Welch%27s_t_test */ -private[stat] object WelchTTest extends OnlineABTestMethod with Logging { +private[stat] object WelchTTest extends StreamingTestMethod with Logging { final val MethodName = "Welch's 2-sample T-test" - final val NullHypothesis = "A and B groups have same mean" + final val NullHypothesis = "Both groups have same mean" private final val TTester = MeatLocker(new TTest()) - def doTest(data: SummaryPairStream): DStream[OnlineABTestResult] = - data.map[OnlineABTestResult]((test _).tupled) + def doTest(data: SummaryPairStream): DStream[StreamingTestResult] = + data.map[StreamingTestResult]((test _).tupled) private def test( statsA: StatCounter, - statsB: StatCounter): OnlineABTestResult = { + statsB: StatCounter): StreamingTestResult = { def welchDF(sample1: StatisticalSummaryValues, sample2: StatisticalSummaryValues): Double = { val s1 = sample1.getVariance val n1 = sample1.getN @@ -101,7 +100,7 @@ private[stat] object WelchTTest extends OnlineABTestMethod with Logging { pow(a + b, 2) / ((pow(a, 2) / (n1 - 1)) + (pow(b, 2) / (n2 - 1))) } - new OnlineABTestResult( + new StreamingTestResult( TTester.get.tTest(statsA, statsB), welchDF(statsA, statsB), TTester.get.t(statsA, statsB), @@ -118,23 +117,23 @@ private[stat] object WelchTTest extends OnlineABTestMethod with Logging { * * More information: http://en.wikipedia.org/wiki/Student%27s_t-test */ -private[stat] object StudentTTest extends OnlineABTestMethod with Logging { +private[stat] object StudentTTest extends StreamingTestMethod with Logging { final val MethodName = "Student's 2-sample T-test" - final val NullHypothesis = "A and B groups have same mean" + final val NullHypothesis = "Both groups have same mean" private final val TTester = MeatLocker(new TTest()) - def doTest(data: SummaryPairStream): DStream[OnlineABTestResult] = - data.map[OnlineABTestResult]((test _).tupled) + def doTest(data: SummaryPairStream): DStream[StreamingTestResult] = + data.map[StreamingTestResult]((test _).tupled) private def test( statsA: StatCounter, - statsB: StatCounter): OnlineABTestResult = { + statsB: StatCounter): StreamingTestResult = { def studentDF(sample1: StatisticalSummaryValues, sample2: StatisticalSummaryValues): Double = sample1.getN + sample2.getN - 2 - new OnlineABTestResult( + new StreamingTestResult( TTester.get.homoscedasticTTest(statsA, statsB), studentDF(statsA, statsB), TTester.get.homoscedasticT(statsA, statsB), @@ -145,24 +144,22 @@ private[stat] object StudentTTest extends OnlineABTestMethod with Logging { } /** - * Maintains supported [[OnlineABTestMethod]] names and handles conversion between strings used in - * [[OnlineABTest]] configuration and actual method implementation. + * Companion object holding supported [[StreamingTestMethod]] names and handles conversion between + * strings used in [[StreamingTest]] configuration and actual method implementation. * - * Currently supported correlations: `welch`, `student`. + * Currently supported tests: `welch`, `student`. */ -private[stat] object OnlineABTestMethodNames { - // Note: after new OnlineABTestMethods are implemented, please update this map. - final val NameToObjectMap = Map(("welch", WelchTTest), ("student", StudentTTest)) - - // Match input correlation name with a known name via simple string matching. - def getTestMethodFromName(method: String): OnlineABTestMethod = { - try { - NameToObjectMap(method) - } catch { - case nse: NoSuchElementException => - throw new IllegalArgumentException("Unrecognized method name. Supported A/B test methods: " - + NameToObjectMap.keys.mkString(", ")) +private[stat] object StreamingTestMethod { + // Note: after new `StreamingTestMethod`s are implemented, please update this map. + final val TEST_NAME_TO_OBJECT = Map(("welch", WelchTTest), ("student", StudentTTest)) + + def getTestMethodFromName(method: String): StreamingTestMethod = + TEST_NAME_TO_OBJECT.get(method) match { + case Some(test) => test + case None => + throw new IllegalArgumentException( + "Unrecognized method name. Supported streaming test methods: " + + TEST_NAME_TO_OBJECT.keys.mkString(", ")) } - } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala index eb9a0cf482c5..89248e9e6697 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala @@ -118,11 +118,11 @@ class KolmogorovSmirnovTestResult private[stat] ( /** * :: Experimental :: - * Object containing the test results for online A/B testing. + * Object containing the test results for streaming testing. */ @Experimental @Since("1.6.0") -private[stat] class OnlineABTestResult( +private[stat] class StreamingTestResult( @Since("1.6.0") override val pValue: Double, @Since("1.6.0") override val degreesOfFreedom: Double, @Since("1.6.0") override val statistic: Double, @@ -131,7 +131,7 @@ private[stat] class OnlineABTestResult( extends TestResult[Double] with Serializable { override def toString: String = { - "A/B test summary:\n" + + "Streaming test summary:\n" + s"method: $method\n" + super.toString } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/OnlineABTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala similarity index 89% rename from mllib/src/test/scala/org/apache/spark/mllib/stat/OnlineABTestSuite.scala rename to mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala index 8847446f1757..55748232a61a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/OnlineABTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala @@ -18,14 +18,13 @@ package org.apache.spark.mllib.stat import org.apache.spark.SparkFunSuite -import org.apache.spark.util.StatCounter - -import org.apache.spark.mllib.stat.test.{OnlineABTest, OnlineABTestResult, StudentTTest, WelchTTest} +import org.apache.spark.mllib.stat.test.{StreamingTest, StreamingTestResult, StudentTTest, WelchTTest} import org.apache.spark.streaming.TestSuiteBase import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.StatCounter import org.apache.spark.util.random.XORShiftRandom -class OnlineABTestSuite extends SparkFunSuite with TestSuiteBase { +class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { override def maxWaitTimeMillis : Int = 30000 @@ -39,7 +38,7 @@ class OnlineABTestSuite extends SparkFunSuite with TestSuiteBase { val meanB = 0 val stdevB = 0.001 - val model = new OnlineABTest() + val model = new StreamingTest() .setWindowSize(0) .setPeacePeriod(0) .setTestMethod(testMethod) @@ -50,7 +49,7 @@ class OnlineABTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) - val outputBatches = runStreams[OnlineABTestResult](ssc, numBatches, numBatches) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) assert(outputBatches.flatten.forall(res => res.pValue > 0.05 && res.method == WelchTTest.MethodName)) @@ -66,7 +65,7 @@ class OnlineABTestSuite extends SparkFunSuite with TestSuiteBase { val meanB = 10 val stdevB = 1 - val model = new OnlineABTest() + val model = new StreamingTest() .setWindowSize(0) .setPeacePeriod(0) .setTestMethod(testMethod) @@ -77,7 +76,7 @@ class OnlineABTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) - val outputBatches = runStreams[OnlineABTestResult](ssc, numBatches, numBatches) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) assert(outputBatches.flatten.forall(res => res.pValue < 0.05 && res.method == WelchTTest.MethodName)) @@ -93,7 +92,7 @@ class OnlineABTestSuite extends SparkFunSuite with TestSuiteBase { val meanB = 0 val stdevB = 0.001 - val model = new OnlineABTest() + val model = new StreamingTest() .setWindowSize(0) .setPeacePeriod(0) .setTestMethod(testMethod) @@ -104,7 +103,7 @@ class OnlineABTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) - val outputBatches = runStreams[OnlineABTestResult](ssc, numBatches, numBatches) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) assert(outputBatches.flatten.forall(res => @@ -121,7 +120,7 @@ class OnlineABTestSuite extends SparkFunSuite with TestSuiteBase { val meanB = 10 val stdevB = 1 - val model = new OnlineABTest() + val model = new StreamingTest() .setWindowSize(0) .setPeacePeriod(0) .setTestMethod(testMethod) @@ -132,7 +131,7 @@ class OnlineABTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) - val outputBatches = runStreams[OnlineABTestResult](ssc, numBatches, numBatches) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) assert(outputBatches.flatten.forall(res => res.pValue < 0.05 && res.method == StudentTTest.MethodName)) @@ -148,7 +147,7 @@ class OnlineABTestSuite extends SparkFunSuite with TestSuiteBase { val meanB = 10 val stdevB = 1 - val model = new OnlineABTest() + val model = new StreamingTest() .setWindowSize(testWindow) .setPeacePeriod(0) @@ -182,7 +181,7 @@ class OnlineABTestSuite extends SparkFunSuite with TestSuiteBase { val meanB = 10 val stdevB = 1 - val model = new OnlineABTest() + val model = new StreamingTest() .setWindowSize(0) .setPeacePeriod(peacePeriod) @@ -206,7 +205,7 @@ class OnlineABTestSuite extends SparkFunSuite with TestSuiteBase { val meanB = 0 val stdevB = 0.001 - val model = new OnlineABTest() + val model = new StreamingTest() .setWindowSize(0) .setPeacePeriod(0) @@ -216,7 +215,7 @@ class OnlineABTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) - val outputBatches = runStreams[OnlineABTestResult](ssc, numBatches, numBatches) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) assert(outputBatches.flatten.forall(result => (result.pValue - 1.0).abs < 0.001)) } From 60b2e57026febcb68e459983ba3164281a47f636 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Fri, 11 Sep 2015 13:21:10 -0700 Subject: [PATCH 4/5] Fixes flaky streaming numBatches --- .../scala/org/apache/spark/mllib/stat/test/StreamingTest.scala | 2 +- .../scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala index 86a1a1753417..01fef0b94616 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala @@ -53,7 +53,7 @@ import org.apache.spark.util.StatCounter class StreamingTest( @Since("1.6.0") var peacePeriod: Int = 0, @Since("1.6.0") var windowSize: Int = 0, - @Since("1.6.0") var testMethod: StreamingTestMethod = WelchTTest) + @Since("1.6.0") var testMethod: StreamingTestMethod = WelchTTest) extends Logging with Serializable { /** Set the number of initial batches to ignore. */ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala index 55748232a61a..da9e1de203cf 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala @@ -158,7 +158,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { val ssc = setupStreams( input, (inputDStream: DStream[(Boolean, Double)]) => model.summarizeByKeyAndWindow(inputDStream)) - val outputBatches = runStreams[(Boolean, StatCounter)](ssc, numBatches, 4) + val outputBatches = runStreams[(Boolean, StatCounter)](ssc, numBatches, numBatches) val outputCounts = outputBatches.flatten.map(_._2.count) // number of batches seen so far does not exceed testWindow, expect counts to continue growing From ba71bfad58d6aedb193e0f7b0cf32747d6a59ce2 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Fri, 18 Sep 2015 13:32:36 -0700 Subject: [PATCH 5/5] Code review fixes --- .../spark/mllib/stat/test/StreamingTest.scala | 40 ++++++++--------- .../mllib/stat/test/StreamingTestMethod.scala | 44 ++++++++++--------- .../spark/mllib/stat/test/TestResult.scala | 2 +- .../spark/mllib/stat/StreamingTestSuite.scala | 8 ++-- 4 files changed, 48 insertions(+), 46 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala index 01fef0b94616..75c6a51d0957 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.stat.test import org.apache.spark.Logging import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.rdd.RDD import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.StatCounter @@ -40,23 +41,22 @@ import org.apache.spark.util.StatCounter * which test will be used. * * Use a builder pattern to construct a streaming test in an application, for example: - * ``` - * val model = new OnlineABTest() + * {{{ + * val model = new StreamingTest() * .setPeacePeriod(10) * .setWindowSize(0) * .setTestMethod("welch") * .registerStream(DStream) - * ``` + * }}} */ @Experimental @Since("1.6.0") -class StreamingTest( - @Since("1.6.0") var peacePeriod: Int = 0, - @Since("1.6.0") var windowSize: Int = 0, - @Since("1.6.0") var testMethod: StreamingTestMethod = WelchTTest) - extends Logging with Serializable { +class StreamingTest @Since("1.6.0") () extends Logging with Serializable { + private var peacePeriod: Int = 0 + private var windowSize: Int = 0 + private var testMethod: StreamingTestMethod = WelchTTest - /** Set the number of initial batches to ignore. */ + /** Set the number of initial batches to ignore. Default: 0. */ @Since("1.6.0") def setPeacePeriod(peacePeriod: Int): this.type = { this.peacePeriod = peacePeriod @@ -64,7 +64,7 @@ class StreamingTest( } /** - * Set the number of batches to compute significance tests over. + * Set the number of batches to compute significance tests over. Default: 0. * A value of 0 will use all batches seen so far. */ @Since("1.6.0") @@ -73,7 +73,7 @@ class StreamingTest( this } - /** Set the statistical method used for significance testing. */ + /** Set the statistical method used for significance testing. Default: "welch" */ @Since("1.6.0") def setTestMethod(method: String): this.type = { this.testMethod = StreamingTestMethod.getTestMethodFromName(method) @@ -83,8 +83,9 @@ class StreamingTest( /** * Register a [[DStream]] of values for significance testing. * - * @param data stream of (key,value) pairs where the key is the group membership (control or - * treatment) and the value is the numerical metric to test for significance + * @param data stream of (key,value) pairs where the key denotes group membership (true = + * experiment, false = control) and the value is the numerical metric to test for + * significance * @return stream of significance testing results */ @Since("1.6.0") @@ -92,9 +93,8 @@ class StreamingTest( val dataAfterPeacePeriod = dropPeacePeriod(data) val summarizedData = summarizeByKeyAndWindow(dataAfterPeacePeriod) val pairedSummaries = pairSummaries(summarizedData) - val testResults = testMethod.doTest(pairedSummaries) - testResults + testMethod.doTest(pairedSummaries) } /** Drop all batches inside the peace period. */ @@ -104,7 +104,7 @@ class StreamingTest( if (time.milliseconds > data.slideDuration.milliseconds * peacePeriod) { rdd } else { - rdd.filter(_ => false) // TODO: Is there a better way to drop a RDD from a DStream? + data.context.sparkContext.parallelize(Seq()) } } } @@ -132,14 +132,14 @@ class StreamingTest( } /** - * Transform a stream of summaries into pairs representing summary statistics for group A and - * group B up to this batch. + * Transform a stream of summaries into pairs representing summary statistics for control group + * and experiment group up to this batch. */ private[stat] def pairSummaries(summarizedData: DStream[(Boolean, StatCounter)]) : DStream[(StatCounter, StatCounter)] = { summarizedData .map[(Int, StatCounter)](x => (0, x._2)) - .groupByKey() // Iterable[StatCounter] should be length two, one for each A/B group - .map(x => (x._2.head, x._2.last) ) + .groupByKey() // should be length two (control/experiment group) + .map(x => (x._2.head, x._2.last)) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala index 23e0dead0297..a7eaed51b4d5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala @@ -37,8 +37,8 @@ import org.apache.spark.util.StatCounter */ private[stat] sealed trait StreamingTestMethod extends Serializable { - val MethodName: String - val NullHypothesis: String + val methodName: String + val nullHypothesis: String protected type SummaryPairStream = DStream[(StatCounter, StatCounter)] @@ -73,16 +73,16 @@ private[stat] sealed trait StreamingTestMethod extends Serializable { * This test does not assume equal variance between the two samples and does not assume equal * sample size. * - * More information: http://en.wikipedia.org/wiki/Welch%27s_t_test + * @see http://en.wikipedia.org/wiki/Welch%27s_t_test */ private[stat] object WelchTTest extends StreamingTestMethod with Logging { - final val MethodName = "Welch's 2-sample T-test" - final val NullHypothesis = "Both groups have same mean" + override final val methodName = "Welch's 2-sample t-test" + override final val nullHypothesis = "Both groups have same mean" - private final val TTester = MeatLocker(new TTest()) + private final val tTester = MeatLocker(new TTest()) - def doTest(data: SummaryPairStream): DStream[StreamingTestResult] = + override def doTest(data: SummaryPairStream): DStream[StreamingTestResult] = data.map[StreamingTestResult]((test _).tupled) private def test( @@ -101,11 +101,11 @@ private[stat] object WelchTTest extends StreamingTestMethod with Logging { } new StreamingTestResult( - TTester.get.tTest(statsA, statsB), + tTester.get.tTest(statsA, statsB), welchDF(statsA, statsB), - TTester.get.t(statsA, statsB), - MethodName, - NullHypothesis + tTester.get.t(statsA, statsB), + methodName, + nullHypothesis ) } } @@ -115,16 +115,16 @@ private[stat] object WelchTTest extends StreamingTestMethod with Logging { * mean. This test assumes equal variance between the two samples and does not assume equal sample * size. For unequal variances, Welch's t-test should be used instead. * - * More information: http://en.wikipedia.org/wiki/Student%27s_t-test + * @see http://en.wikipedia.org/wiki/Student%27s_t-test */ private[stat] object StudentTTest extends StreamingTestMethod with Logging { - final val MethodName = "Student's 2-sample T-test" - final val NullHypothesis = "Both groups have same mean" + override final val methodName = "Student's 2-sample t-test" + override final val nullHypothesis = "Both groups have same mean" - private final val TTester = MeatLocker(new TTest()) + private final val tTester = MeatLocker(new TTest()) - def doTest(data: SummaryPairStream): DStream[StreamingTestResult] = + override def doTest(data: SummaryPairStream): DStream[StreamingTestResult] = data.map[StreamingTestResult]((test _).tupled) private def test( @@ -134,11 +134,11 @@ private[stat] object StudentTTest extends StreamingTestMethod with Logging { sample1.getN + sample2.getN - 2 new StreamingTestResult( - TTester.get.homoscedasticTTest(statsA, statsB), + tTester.get.homoscedasticTTest(statsA, statsB), studentDF(statsA, statsB), - TTester.get.homoscedasticT(statsA, statsB), - MethodName, - NullHypothesis + tTester.get.homoscedasticT(statsA, statsB), + methodName, + nullHypothesis ) } } @@ -151,7 +151,9 @@ private[stat] object StudentTTest extends StreamingTestMethod with Logging { */ private[stat] object StreamingTestMethod { // Note: after new `StreamingTestMethod`s are implemented, please update this map. - final val TEST_NAME_TO_OBJECT = Map(("welch", WelchTTest), ("student", StudentTTest)) + private final val TEST_NAME_TO_OBJECT: Map[String, StreamingTestMethod] = Map( + "welch"->WelchTTest, + "student"->StudentTTest) def getTestMethodFromName(method: String): StreamingTestMethod = TEST_NAME_TO_OBJECT.get(method) match { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala index 89248e9e6697..b0916d3e8465 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala @@ -122,7 +122,7 @@ class KolmogorovSmirnovTestResult private[stat] ( */ @Experimental @Since("1.6.0") -private[stat] class StreamingTestResult( +private[stat] class StreamingTestResult @Since("1.6.0") ( @Since("1.6.0") override val pValue: Double, @Since("1.6.0") override val degreesOfFreedom: Double, @Since("1.6.0") override val statistic: Double, diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala index da9e1de203cf..d3e9ef4ff079 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala @@ -52,7 +52,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) assert(outputBatches.flatten.forall(res => - res.pValue > 0.05 && res.method == WelchTTest.MethodName)) + res.pValue > 0.05 && res.method == WelchTTest.methodName)) } test("accuracy for alternative hypothesis using welch t-test") { @@ -79,7 +79,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) assert(outputBatches.flatten.forall(res => - res.pValue < 0.05 && res.method == WelchTTest.MethodName)) + res.pValue < 0.05 && res.method == WelchTTest.methodName)) } test("accuracy for null hypothesis using student t-test") { @@ -107,7 +107,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { assert(outputBatches.flatten.forall(res => - res.pValue > 0.05 && res.method == StudentTTest.MethodName)) + res.pValue > 0.05 && res.method == StudentTTest.methodName)) } test("accuracy for alternative hypothesis using student t-test") { @@ -134,7 +134,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) assert(outputBatches.flatten.forall(res => - res.pValue < 0.05 && res.method == StudentTTest.MethodName)) + res.pValue < 0.05 && res.method == StudentTTest.methodName)) } test("batches within same test window are grouped") {