
Commit f1114b9

Author: Feynman Liang (committed)
Add -1 delimiter

1 parent 00fe756 · commit f1114b9

3 files changed: 32 additions (+), 17 deletions (−)


mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
(7 additions, 4 deletions)

@@ -25,7 +25,7 @@ import org.apache.spark.Logging
  * Calculate all patterns of a projected database in local.
  */
 private[fpm] object LocalPrefixSpan extends Logging with Serializable {
-
+  import PrefixSpan._
   /**
    * Calculate all patterns of a projected database.
    * @param minCount minimum count
@@ -43,7 +43,9 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
       database: Iterable[Array[Int]]): Iterator[(List[Int], Long)] = {
     if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
     val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
-    val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
+    val filteredDatabase = database.map { suffix =>
+      suffix.filter(item => item == DELIMITER || frequentItemAndCounts.contains(item))
+    }
     frequentItemAndCounts.iterator.flatMap { case (item, count) =>
       val newPrefixes = item :: prefixes
       val newProjected = project(filteredDatabase, item)
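
The new filter keeps DELIMITER tokens, so itemset boundaries survive projection; the old x.filter(frequentItemAndCounts.contains) would strip the markers along with infrequent items. A minimal standalone sketch with made-up data (not from the commit):

    object FilterSketch extends App {
      val DELIMITER = -1
      val frequent = Set(1, 3)                           // hypothetical frequent items
      val suffix = Array(1, 2, DELIMITER, 3, DELIMITER)  // encodes <(1 2) (3)>
      // Keep delimiters in addition to frequent items:
      val filtered = suffix.filter(item => item == DELIMITER || frequent.contains(item))
      println(filtered.mkString(", "))  // 1, -1, 3, -1 : infrequent 2 gone, boundaries kept
    }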
@@ -63,7 +65,8 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
     if (index == -1) {
       Array()
     } else {
-      sequence.drop(index + 1)
+      // drop until we get to the next delimiter (or end of sequence)
+      sequence.drop(index).dropWhile(_ != DELIMITER).drop(1)
     }
   }

@@ -89,6 +92,6 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
         counts(item) += 1L
       }
     }
-    counts.filter(_._2 >= minCount)
+    counts.filter { case (item, count) => (count >= minCount) && (item != DELIMITER) }
   }
 }
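
With delimiters present, getSuffix must skip the remainder of the itemset containing the matched item before projecting. A self-contained sketch mirroring the changed lines (the object name and main harness are illustrative, not part of the commit):

    object SuffixSketch {
      val DELIMITER = -1  // same sentinel the commit introduces in object PrefixSpan

      // Suffix of `sequence` after the first `item`, starting at the next itemset.
      def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = {
        val index = sequence.indexOf(item)
        if (index == -1) {
          Array()
        } else {
          // drop until we get to the next delimiter (or end of sequence)
          sequence.drop(index).dropWhile(_ != DELIMITER).drop(1)
        }
      }

      def main(args: Array[String]): Unit = {
        val seq = Array(1, 2, DELIMITER, 3, DELIMITER)  // <(1 2) (3)>
        println(getSuffix(1, seq).mkString(", "))       // 3, -1  (rest of (1 2) skipped)
        println(getSuffix(7, seq).mkString(", "))       // empty: 7 never occurs
      }
    }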

mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
(12 additions, 6 deletions)

@@ -44,6 +44,7 @@ import org.apache.spark.storage.StorageLevel
 class PrefixSpan private (
     private var minSupport: Double,
     private var maxPatternLength: Int) extends Logging with Serializable {
+  import PrefixSpan._

   /**
    * The maximum number of items allowed in a projected database before local processing. If a
@@ -90,12 +91,14 @@ class PrefixSpan private (

   /**
    * Find the complete set of sequential patterns in the input sequences.
-   * @param sequences input data set, contains a set of sequences,
-   *                  a sequence is an ordered list of elements.
+   * @param sequences a dataset of sequences. Items in a sequence are represented by non-negative
+   *                  integers and delimited by [[DELIMITER]]. Non-temporal sequences
+   *                  are supported by placing more than one item between delimiters.
    * @return a set of sequential pattern pairs,
    *         the key of pair is pattern (a list of elements),
    *         the value of pair is the pattern's count.
    */
+  // TODO: generalize to arbitrary item-types and use mapping to Ints for internal algorithm
   def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
     val sc = sequences.sparkContext

@@ -110,14 +113,14 @@ class PrefixSpan private (
     val freqItemCounts = sequences
       .flatMap(seq => seq.distinct.map(item => (item, 1L)))
       .reduceByKey(_ + _)
-      .filter(_._2 >= minCount)
+      .filter { case (item, count) => (count >= minCount) && (item != DELIMITER) }
       .collect()

     // Pairs of (length 1 prefix, suffix consisting of frequent items)
     val itemSuffixPairs = {
       val freqItems = freqItemCounts.map(_._1).toSet
       sequences.flatMap { seq =>
-        val filteredSeq = seq.filter(freqItems.contains(_))
+        val filteredSeq = seq.filter(item => item == DELIMITER || freqItems.contains(item))
         freqItems.flatMap { item =>
           val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq)
           candidateSuffix match {
@@ -127,7 +130,6 @@ class PrefixSpan private (
         }
       }
     }
-
     // Accumulator for the computed results to be returned, initialized to the frequent items (i.e.
     // frequent length-one prefixes)
     var resultsAccumulator = freqItemCounts.map(x => (List(x._1), x._2))
@@ -197,7 +199,7 @@ class PrefixSpan private (
     val prefixItemPairAndCounts = prefixSuffixPairs
       .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) }
       .reduceByKey(_ + _)
-      .filter(_._2 >= minCount)
+      .filter { case (item, count) => (count >= minCount) && (item != DELIMITER) }

     // Map from prefix to set of possible next items from suffix
     val prefixToNextItems = prefixItemPairAndCounts
@@ -247,3 +249,7 @@ class PrefixSpan private (
   }
 }
+
+object PrefixSpan {
+  val DELIMITER = -1
+}
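
Per the updated Scaladoc, callers encode each sequence of itemsets as a flat Array[Int], terminating every itemset with DELIMITER. A REPL-style sketch of the encoding (example data only):

    import org.apache.spark.mllib.fpm.PrefixSpan

    // The sequence of itemsets <(12) (10 20) (30)> flattens to:
    val encoded: Array[Int] =
      Array(12, PrefixSpan.DELIMITER, 10, 20, PrefixSpan.DELIMITER, 30, PrefixSpan.DELIMITER)
    // == Array(12, -1, 10, 20, -1, 30, -1); the two items between the first and
    // second delimiter form the "non-temporal" group the Scaladoc mentions.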

mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
(13 additions, 7 deletions)

@@ -40,7 +40,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
       Array(2, 4, 1),
       Array(3, 1, 3, 4, 5),
       Array(3, 4, 4, 3),
-      Array(6, 5, 3))
+      Array(6, 5, 3)).map(insertDelimiter)

     val rdd = sc.parallelize(sequences, 2).cache()

@@ -69,7 +69,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
       (Array(4, 5), 2L),
       (Array(5), 3L)
     )
-    assert(compareResults(expectedValue1, result1.collect()))
+    compareResults(expectedValue1, result1.collect())

     prefixspan.setMinSupport(0.5).setMaxPatternLength(50)
     val result2 = prefixspan.run(rdd)
@@ -80,7 +80,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
       (Array(4), 4L),
       (Array(5), 3L)
     )
-    assert(compareResults(expectedValue2, result2.collect()))
+    compareResults(expectedValue2, result2.collect())

     prefixspan.setMinSupport(0.33).setMaxPatternLength(2)
     val result3 = prefixspan.run(rdd)
@@ -100,14 +100,20 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
       (Array(4, 5), 2L),
       (Array(5), 3L)
     )
-    assert(compareResults(expectedValue3, result3.collect()))
+    compareResults(expectedValue3, result3.collect())
   }

   private def compareResults(
       expectedValue: Array[(Array[Int], Long)],
-      actualValue: Array[(Array[Int], Long)]): Boolean = {
-    expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
-      actualValue.map(x => (x._1.toSeq, x._2)).toSet
+      actualValue: Array[(Array[Int], Long)]): Unit = {
+    assert(expectedValue.map(x => (x._1.toSeq, x._2)).toSet ===
+      actualValue.map(x => (x._1.toSeq, x._2)).toSet)
+  }
+
+  private def insertDelimiter(sequence: Array[Int]): Array[Int] = {
+    sequence.zip(Seq.fill(sequence.length)(PrefixSpan.DELIMITER)).map { case (a, b) =>
+      List(a, b)
+    }.flatten
   }

 }
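
The new insertDelimiter helper turns every item of a flat test sequence into its own single-element itemset, e.g. Array(1, 3, 5) becomes Array(1, -1, 3, -1, 5, -1). The suite's zip/flatten is equivalent to this more direct flatMap (a sketch, not the committed code):

    private def insertDelimiter(sequence: Array[Int]): Array[Int] =
      sequence.flatMap(item => Array(item, PrefixSpan.DELIMITER))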
