Skip to content

Commit 03156ca

Browse files
author
Feynman Liang
committed
Fix tests, drop delimiters at boundaries of sequences
1 parent d1fe0ed commit 03156ca

File tree

2 files changed

Lines changed: 102 additions & 113 deletions

2 files changed

Lines changed: 102 additions & 113 deletions

mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -48,8 +48,8 @@ class PrefixSpan private (
4848
* The maximum number of items allowed in a projected database before local processing. If a
4949
* projected database exceeds this size, another iteration of distributed PrefixSpan is run.
5050
*/
51-
// TODO: make configurable with a better default value, 10000 may be too small
52-
private val maxLocalProjDBSize: Long = 10000
51+
// TODO: make configurable with a better default value
52+
private val maxLocalProjDBSize: Long = 32000000L
5353

5454
/**
5555
* Constructs a default instance with default parameters
@@ -269,8 +269,8 @@ object PrefixSpan {
269269
// TODO: avoid allocating new arrays when appending
270270
sequence.zip(Seq.fill(sequence.size)(PrefixSpan.DELIMITER))
271271
.flatMap { case (a: Set[Int], b: Int) =>
272-
a.toList.sorted :+ b
273-
}
272+
b :: a.toList.sorted
273+
}.drop(1) // drop leading delimiter
274274
}
275275

276276
private[fpm] def nonemptySubsets(itemSet: Set[Int]): Iterator[Set[Int]] = {

mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala

Lines changed: 98 additions & 109 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
2121

2222
class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
2323

24-
test("PrefixSpan using Integer type") {
24+
test("PrefixSpan using Integer type, singleton itemsets") {
2525

2626
/*
2727
library("arulesSequences")
@@ -35,12 +35,12 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
3535
*/
3636

3737
val sequences = Array(
38-
Array(1, 3, 4, 5),
39-
Array(2, 3, 1),
40-
Array(2, 4, 1),
41-
Array(3, 1, 3, 4, 5),
42-
Array(3, 4, 4, 3),
43-
Array(6, 5, 3)).map(insertDelimiter)
38+
Array(1, -1, 3, -1, 4, -1, 5),
39+
Array(2, -1, 3, -1, 1),
40+
Array(2, -1, 4, -1, 1),
41+
Array(3, -1, 1, -1, 3, -1, 4, -1, 5),
42+
Array(3, -1, 4, -1, 4, -1, 3),
43+
Array(6, -1, 5, -1, 3))
4444

4545
val rdd = sc.parallelize(sequences, 2).cache()
4646

@@ -50,135 +50,124 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
5050
val result1 = prefixspan.run(rdd)
5151
val expectedValue1 = Array(
5252
(Array(1), 4L),
53-
(Array(1, 3), 2L),
54-
(Array(1, 3, 4), 2L),
55-
(Array(1, 3, 4, 5), 2L),
56-
(Array(1, 3, 5), 2L),
57-
(Array(1, 4), 2L),
58-
(Array(1, 4, 5), 2L),
59-
(Array(1, 5), 2L),
53+
(Array(1, -1, 3), 2L),
54+
(Array(1, -1, 3, -1, 4), 2L),
55+
(Array(1, -1, 3, -1, 4, -1, 5), 2L),
56+
(Array(1, -1, 3, -1, 5), 2L),
57+
(Array(1, -1, 4), 2L),
58+
(Array(1, -1, 4, -1, 5), 2L),
59+
(Array(1, -1, 5), 2L),
6060
(Array(2), 2L),
61-
(Array(2, 1), 2L),
61+
(Array(2, -1, 1), 2L),
6262
(Array(3), 5L),
63-
(Array(3, 1), 2L),
64-
(Array(3, 3), 2L),
65-
(Array(3, 4), 3L),
66-
(Array(3, 4, 5), 2L),
67-
(Array(3, 5), 2L),
63+
(Array(3, -1, 1), 2L),
64+
(Array(3, -1, 3), 2L),
65+
(Array(3, -1, 4), 3L),
66+
(Array(3, -1, 4, -1, 5), 2L),
67+
(Array(3, -1, 5), 2L),
6868
(Array(4), 4L),
69-
(Array(4, 5), 2L),
69+
(Array(4, -1, 5), 2L),
7070
(Array(5), 3L)
71-
).map { case (seq, count) => (insertDelimiter(seq), count) }
71+
)
7272
compareResults(expectedValue1, result1.collect())
7373

7474
prefixspan.setMinSupport(0.5).setMaxPatternLength(50)
7575
val result2 = prefixspan.run(rdd)
7676
val expectedValue2 = Array(
7777
(Array(1), 4L),
7878
(Array(3), 5L),
79-
(Array(3, 4), 3L),
79+
(Array(3, -1, 4), 3L),
8080
(Array(4), 4L),
8181
(Array(5), 3L)
82-
).map { case (seq, count) => (insertDelimiter(seq), count) }
82+
)
8383
compareResults(expectedValue2, result2.collect())
8484

8585
prefixspan.setMinSupport(0.33).setMaxPatternLength(2)
8686
val result3 = prefixspan.run(rdd)
8787
val expectedValue3 = Array(
8888
(Array(1), 4L),
89-
(Array(1, 3), 2L),
90-
(Array(1, 4), 2L),
91-
(Array(1, 5), 2L),
92-
(Array(2, 1), 2L),
89+
(Array(1, -1, 3), 2L),
90+
(Array(1, -1, 4), 2L),
91+
(Array(1, -1, 5), 2L),
92+
(Array(2, -1, 1), 2L),
9393
(Array(2), 2L),
9494
(Array(3), 5L),
95-
(Array(3, 1), 2L),
96-
(Array(3, 3), 2L),
97-
(Array(3, 4), 3L),
98-
(Array(3, 5), 2L),
95+
(Array(3, -1, 1), 2L),
96+
(Array(3, -1, 3), 2L),
97+
(Array(3, -1, 4), 3L),
98+
(Array(3, -1, 5), 2L),
9999
(Array(4), 4L),
100-
(Array(4, 5), 2L),
100+
(Array(4, -1, 5), 2L),
101101
(Array(5), 3L)
102-
).map { case (seq, count) => (insertDelimiter(seq), count) }
102+
)
103103
compareResults(expectedValue3, result3.collect())
104104
}
105105

106-
test("PrefixSpan non-temporal sequences") {
106+
test("PrefixSpan using Integer type, variable-size itemsets") {
107107
val sequences = Array(
108-
"a,abc,ac,d,cf",
109-
"ad,c,bc,ae",
110-
"ef,ab,df,c,b",
111-
"e,g,af,c,b,c")
112-
val coder = Array('a', 'b', 'c', 'd', 'e', 'f', 'g').zip(Array(1, 2, 3, 4, 5, 6, 7)).toMap
113-
val intSequences = sequences.map(_.split(",").flatMap(_.toArray.map(coder) :+ -1))
114-
val data = sc.parallelize(intSequences, 2).cache()
115-
116-
val expectedValue4 = Array(
117-
"a:4",
118-
"b:4",
119-
"c:4",
120-
"d:3",
121-
"e:3",
122-
"f:3",
123-
"a,a:2",
124-
"a,b:4",
125-
"a,bc:2",
126-
"a,bc,a:2",
127-
"a,b,a:2",
128-
"a,b,c:2",
129-
"ab:2",
130-
"ab,c:2",
131-
"ab,d:2",
132-
"ab,d,c:2",
133-
"ab,f:2",
134-
"a,c:4",
135-
"a,c,a:2",
136-
"a,c,b:3",
137-
"a,c,c:3",
138-
"a,d:2",
139-
"a,d,c:2",
140-
"a,f:2",
141-
"b,a:2",
142-
"b,c:3",
143-
"bc:2",
144-
"bc,a:2",
145-
"b,d:2",
146-
"b,d,c:2",
147-
"b,f:2",
148-
"c,a:2",
149-
"c,b:3",
150-
"c,c:3",
151-
"d,b:2",
152-
"d,c:3",
153-
"d,c,b:2",
154-
"e,a:2",
155-
"e,a,b:2",
156-
"e,a,c:2",
157-
"e,a,c,b:2",
158-
"e,b:2",
159-
"e,b,c:2",
160-
"e,c:2",
161-
"e,c,b:2",
162-
"e,f:2",
163-
"e,f,b:2",
164-
"e,f,c:2",
165-
"e,f,c,b:2",
166-
"f,b:2",
167-
"f,b,c:2",
168-
"f,c:2",
169-
"f,c,b:2")
170-
val intExpectedValue = expectedValue4
171-
.map(_.split(":"))
172-
.map { x =>
173-
(x(0).split(",").flatMap(_.toArray.sorted.map(coder) :+ -1), x(1).toLong)
174-
}
175-
176-
val prefixspan = new PrefixSpan()
177-
.setMinSupport(0.5)
178-
.setMaxPatternLength(5)
179-
180-
val results = prefixspan.run(data)
181-
compareResults(intExpectedValue, results.collect())
108+
Array(1, -1, 1, 2, 3, -1, 1, 3, -1, 4, -1, 3, 6),
109+
Array(1, 4, -1, 3, -1, 2, 3, -1, 1, 5),
110+
Array(5, 6, -1, 1, 2, -1, 4, 6, -1, 3, -1, 2),
111+
Array(5, -1, 7, -1, 1, 6, -1, 3, -1, 2, -1, 3))
112+
val rdd = sc.parallelize(sequences, 2).cache()
113+
val prefixspan = new PrefixSpan().setMinSupport(0.5).setMaxPatternLength(5)
114+
val result = prefixspan.run(rdd)
115+
val expectedValue = Array(
116+
(Array(1), 4L),
117+
(Array(1, -1, 6), 2L),
118+
(Array(1, -1, 1), 2L),
119+
(Array(1, -1, 3), 4L),
120+
(Array(1, -1, 3, -1, 1), 2L),
121+
(Array(1, -1, 3, -1, 3), 3L),
122+
(Array(1, -1, 3, -1, 2), 3L),
123+
(Array(1, -1, 2), 4L),
124+
(Array(1, -1, 2, -1, 1), 2L),
125+
(Array(1, -1, 2, -1, 3), 2L),
126+
(Array(1, -1, 2, 3), 2L),
127+
(Array(1, -1, 2, 3, -1, 1), 2L),
128+
(Array(1, 2), 2L),
129+
(Array(1, 2, -1, 6), 2L),
130+
(Array(1, 2, -1, 3), 2L),
131+
(Array(1, 2, -1, 4), 2L),
132+
(Array(1, 2, -1, 4, -1, 3), 2L),
133+
(Array(1, -1, 4), 2L),
134+
(Array(1, -1, 4, -1, 3), 2L),
135+
(Array(2), 4L),
136+
(Array(2, -1, 6), 2L),
137+
(Array(2, -1, 1), 2L),
138+
(Array(2, -1, 3), 3L),
139+
(Array(2, 3), 2L),
140+
(Array(2, 3, -1, 1), 2L),
141+
(Array(2, -1, 4), 2L),
142+
(Array(2, -1, 4, -1, 3), 2L),
143+
(Array(3), 4L),
144+
(Array(3, -1, 1), 2L),
145+
(Array(3, -1, 3), 3L),
146+
(Array(3, -1, 2), 3L),
147+
(Array(4), 3L),
148+
(Array(4, -1, 3), 3L),
149+
(Array(4, -1, 3, -1, 2), 2L),
150+
(Array(4, -1, 2), 2L),
151+
(Array(5), 3L),
152+
(Array(5, -1, 6), 2L),
153+
(Array(5, -1, 6, -1, 3), 2L),
154+
(Array(5, -1, 6, -1, 3, -1, 2), 2L),
155+
(Array(5, -1, 6, -1, 2), 2L),
156+
(Array(5, -1, 1), 2L),
157+
(Array(5, -1, 1, -1, 3), 2L),
158+
(Array(5, -1, 1, -1, 3, -1, 2), 2L),
159+
(Array(5, -1, 1, -1, 2), 2L),
160+
(Array(5, -1, 3), 2L),
161+
(Array(5, -1, 3, -1, 2), 2L),
162+
(Array(5, -1, 2), 2L),
163+
(Array(5, -1, 2, -1, 3), 2L),
164+
(Array(6), 3L),
165+
(Array(6, -1, 3), 2L),
166+
(Array(6, -1, 3, -1, 2), 2L),
167+
(Array(6, -1, 2), 2L),
168+
(Array(6, -1, 2, -1, 3), 2L))
169+
170+
compareResults(expectedValue, result.collect())
182171
}
183172

184173
private def compareResults(

0 commit comments

Comments (0)