Skip to content

Commit aeef44a

Browse files
Feynman Liangmengxr
authored andcommitted
[SPARK-3147] [MLLIB] [STREAMING] Streaming 2-sample statistical significance testing
Implementation of significance testing using Streaming API. Author: Feynman Liang <[email protected]> Author: Feynman Liang <[email protected]> Closes #4716 from feynmanliang/ab_testing.
1 parent ba882db commit aeef44a

File tree

5 files changed

+667
-0
lines changed

5 files changed

+667
-0
lines changed
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.examples.mllib
19+
20+
import org.apache.spark.SparkConf
21+
import org.apache.spark.mllib.stat.test.StreamingTest
22+
import org.apache.spark.streaming.{Seconds, StreamingContext}
23+
import org.apache.spark.util.Utils
24+
25+
/**
26+
* Perform streaming testing using Welch's 2-sample t-test on a stream of data, where the data
27+
* stream arrives as text files in a directory. Stops when the two groups are statistically
28+
* significant (p-value < 0.05) or after a user-specified timeout in number of batches is exceeded.
29+
*
30+
* The rows of the text files must be in the form `Boolean, Double`. For example:
31+
* false, -3.92
32+
* true, 99.32
33+
*
34+
* Usage:
35+
* StreamingTestExample <dataDir> <batchDuration> <numBatchesTimeout>
36+
*
37+
* To run on your local machine using the directory `dataDir` with 5 seconds between each batch and
38+
* a timeout after 100 insignificant batches, call:
39+
* $ bin/run-example mllib.StreamingTestExample dataDir 5 100
40+
*
41+
* As you add text files to `dataDir` the significance test wil continually update every
42+
* `batchDuration` seconds until the test becomes significant (p-value < 0.05) or the number of
43+
* batches processed exceeds `numBatchesTimeout`.
44+
*/
45+
object StreamingTestExample {
46+
47+
def main(args: Array[String]) {
48+
if (args.length != 3) {
49+
// scalastyle:off println
50+
System.err.println(
51+
"Usage: StreamingTestExample " +
52+
"<dataDir> <batchDuration> <numBatchesTimeout>")
53+
// scalastyle:on println
54+
System.exit(1)
55+
}
56+
val dataDir = args(0)
57+
val batchDuration = Seconds(args(1).toLong)
58+
val numBatchesTimeout = args(2).toInt
59+
60+
val conf = new SparkConf().setMaster("local").setAppName("StreamingTestExample")
61+
val ssc = new StreamingContext(conf, batchDuration)
62+
ssc.checkpoint({
63+
val dir = Utils.createTempDir()
64+
dir.toString
65+
})
66+
67+
val data = ssc.textFileStream(dataDir).map(line => line.split(",") match {
68+
case Array(label, value) => (label.toBoolean, value.toDouble)
69+
})
70+
71+
val streamingTest = new StreamingTest()
72+
.setPeacePeriod(0)
73+
.setWindowSize(0)
74+
.setTestMethod("welch")
75+
76+
val out = streamingTest.registerStream(data)
77+
out.print()
78+
79+
// Stop processing if test becomes significant or we time out
80+
var timeoutCounter = numBatchesTimeout
81+
out.foreachRDD { rdd =>
82+
timeoutCounter -= 1
83+
val anySignificant = rdd.map(_.pValue < 0.05).fold(false)(_ || _)
84+
if (timeoutCounter == 0 || anySignificant) rdd.context.stop()
85+
}
86+
87+
ssc.start()
88+
ssc.awaitTermination()
89+
}
90+
}
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.stat.test
19+
20+
import org.apache.spark.Logging
21+
import org.apache.spark.annotation.{Experimental, Since}
22+
import org.apache.spark.rdd.RDD
23+
import org.apache.spark.streaming.dstream.DStream
24+
import org.apache.spark.util.StatCounter
25+
26+
/**
27+
* :: Experimental ::
28+
* Performs online 2-sample significance testing for a stream of (Boolean, Double) pairs. The
29+
* Boolean identifies which sample each observation comes from, and the Double is the numeric value
30+
* of the observation.
31+
*
32+
* To address novelty affects, the `peacePeriod` specifies a set number of initial
33+
* [[org.apache.spark.rdd.RDD]] batches of the [[DStream]] to be dropped from significance testing.
34+
*
35+
* The `windowSize` sets the number of batches each significance test is to be performed over. The
36+
* window is sliding with a stride length of 1 batch. Setting windowSize to 0 will perform
37+
* cumulative processing, using all batches seen so far.
38+
*
39+
* Different tests may be used for assessing statistical significance depending on assumptions
40+
* satisfied by data. For more details, see [[StreamingTestMethod]]. The `testMethod` specifies
41+
* which test will be used.
42+
*
43+
* Use a builder pattern to construct a streaming test in an application, for example:
44+
* {{{
45+
* val model = new StreamingTest()
46+
* .setPeacePeriod(10)
47+
* .setWindowSize(0)
48+
* .setTestMethod("welch")
49+
* .registerStream(DStream)
50+
* }}}
51+
*/
52+
@Experimental
53+
@Since("1.6.0")
54+
class StreamingTest @Since("1.6.0") () extends Logging with Serializable {
55+
private var peacePeriod: Int = 0
56+
private var windowSize: Int = 0
57+
private var testMethod: StreamingTestMethod = WelchTTest
58+
59+
/** Set the number of initial batches to ignore. Default: 0. */
60+
@Since("1.6.0")
61+
def setPeacePeriod(peacePeriod: Int): this.type = {
62+
this.peacePeriod = peacePeriod
63+
this
64+
}
65+
66+
/**
67+
* Set the number of batches to compute significance tests over. Default: 0.
68+
* A value of 0 will use all batches seen so far.
69+
*/
70+
@Since("1.6.0")
71+
def setWindowSize(windowSize: Int): this.type = {
72+
this.windowSize = windowSize
73+
this
74+
}
75+
76+
/** Set the statistical method used for significance testing. Default: "welch" */
77+
@Since("1.6.0")
78+
def setTestMethod(method: String): this.type = {
79+
this.testMethod = StreamingTestMethod.getTestMethodFromName(method)
80+
this
81+
}
82+
83+
/**
84+
* Register a [[DStream]] of values for significance testing.
85+
*
86+
* @param data stream of (key,value) pairs where the key denotes group membership (true =
87+
* experiment, false = control) and the value is the numerical metric to test for
88+
* significance
89+
* @return stream of significance testing results
90+
*/
91+
@Since("1.6.0")
92+
def registerStream(data: DStream[(Boolean, Double)]): DStream[StreamingTestResult] = {
93+
val dataAfterPeacePeriod = dropPeacePeriod(data)
94+
val summarizedData = summarizeByKeyAndWindow(dataAfterPeacePeriod)
95+
val pairedSummaries = pairSummaries(summarizedData)
96+
97+
testMethod.doTest(pairedSummaries)
98+
}
99+
100+
/** Drop all batches inside the peace period. */
101+
private[stat] def dropPeacePeriod(
102+
data: DStream[(Boolean, Double)]): DStream[(Boolean, Double)] = {
103+
data.transform { (rdd, time) =>
104+
if (time.milliseconds > data.slideDuration.milliseconds * peacePeriod) {
105+
rdd
106+
} else {
107+
data.context.sparkContext.parallelize(Seq())
108+
}
109+
}
110+
}
111+
112+
/** Compute summary statistics over each key and the specified test window size. */
113+
private[stat] def summarizeByKeyAndWindow(
114+
data: DStream[(Boolean, Double)]): DStream[(Boolean, StatCounter)] = {
115+
if (this.windowSize == 0) {
116+
data.updateStateByKey[StatCounter](
117+
(newValues: Seq[Double], oldSummary: Option[StatCounter]) => {
118+
val newSummary = oldSummary.getOrElse(new StatCounter())
119+
newSummary.merge(newValues)
120+
Some(newSummary)
121+
})
122+
} else {
123+
val windowDuration = data.slideDuration * this.windowSize
124+
data
125+
.groupByKeyAndWindow(windowDuration)
126+
.mapValues { values =>
127+
val summary = new StatCounter()
128+
values.foreach(value => summary.merge(value))
129+
summary
130+
}
131+
}
132+
}
133+
134+
/**
135+
* Transform a stream of summaries into pairs representing summary statistics for control group
136+
* and experiment group up to this batch.
137+
*/
138+
private[stat] def pairSummaries(summarizedData: DStream[(Boolean, StatCounter)])
139+
: DStream[(StatCounter, StatCounter)] = {
140+
summarizedData
141+
.map[(Int, StatCounter)](x => (0, x._2))
142+
.groupByKey() // should be length two (control/experiment group)
143+
.map(x => (x._2.head, x._2.last))
144+
}
145+
}

0 commit comments

Comments
 (0)