Commit 1a7fb48

Author: Feynman Liang
Add delimiter to results
1 parent 5db00aa commit 1a7fb48

File tree: 3 files changed, +219 -191 lines


mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala

Lines changed: 17 additions & 14 deletions
@@ -40,16 +40,17 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
       minCount: Long,
       maxPatternLength: Int,
       prefixes: List[Int],
-      database: Iterable[Array[Int]]): Iterator[(List[Int], Long)] = {
+      database: Iterable[List[Int]]): Iterator[(List[Int], Long)] = {
     if (prefixes.count(_ == DELIMITER) == maxPatternLength || database.isEmpty) {
       return Iterator.empty
     }
     val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
     val filteredDatabase = database.map { suffix =>
-      suffix.filter(item => item == DELIMITER || frequentItemAndCounts.contains(item))
+      insertDelimiters(
+        splitAtDelimiter(suffix).filter(item => frequentItemAndCounts.contains(item)))
     }
     frequentItemAndCounts.iterator.flatMap { case (item, count) =>
-      val newPrefixes = DELIMITER :: item :: prefixes
+      val newPrefixes = DELIMITER :: item ::: prefixes
       val newProjected = project(filteredDatabase, item)
       Iterator.single((newPrefixes, count)) ++
         run(minCount, maxPatternLength, newPrefixes, newProjected)
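
Note: the prefix list now carries one DELIMITER per itemset, so the maxPatternLength check above counts delimiters rather than raw items. A minimal, self-contained sketch of this encoding (the object name and the concrete itemsets are illustrative, not part of the commit):

object PrefixEncodingSketch {
  val DELIMITER = -1

  def main(args: Array[String]): Unit = {
    // Extend the empty prefix one itemset at a time, newest itemset at the
    // head, exactly as `DELIMITER :: item ::: prefixes` does above.
    val p0 = List.empty[Int]
    val p1 = DELIMITER :: List(1, 2) ::: p0 // List(-1, 1, 2): pattern <(1 2)>
    val p2 = DELIMITER :: List(3) ::: p1    // List(-1, 3, -1, 1, 2), stored newest-first: <(1 2) (3)>

    // Pattern length in itemsets equals the number of delimiters, which is
    // what `prefixes.count(_ == DELIMITER) == maxPatternLength` tests.
    assert(p2.count(_ == DELIMITER) == 2)
  }
}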
@@ -58,21 +59,21 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
 
   /**
    * Calculate suffix sequence immediately after the first occurrence of an item.
-   * @param item item to get suffix after
+   * @param item itemset to get suffix after
    * @param sequence sequence to extract suffix from
    * @return suffix sequence
    */
-  def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = {
-    val index = sequence.indexOf(item)
+  def getSuffix(item: List[Int], sequence: List[Int]): List[Int] = {
+    val itemsetSeq = splitAtDelimiter(sequence)
+    val index = itemsetSeq.indexOf(item)
     if (index == -1) {
-      Array()
+      List()
     } else {
-      // in case index is inside an itemset, drop until we get to the next delimiter (or end of seq)
-      sequence.drop(index).dropWhile(_ != DELIMITER).drop(1)
+      insertDelimiters(itemsetSeq.drop(index + 1))
     }
   }
 
-  def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = {
+  def project(database: Iterable[List[Int]], prefix: List[Int]): Iterable[List[Int]] = {
     database
       .map(getSuffix(prefix, _))
       .filter(_.nonEmpty)
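
getSuffix now works at itemset granularity: the sequence is split into itemsets, the first occurrence of the whole itemset `item` is located, and everything after it is re-joined with delimiters. A runnable sketch with local copies of the splitAtDelimiter/insertDelimiters helpers introduced at the bottom of this commit (insertDelimiters is written here as an equivalent flatMap):

object GetSuffixSketch {
  val DELIMITER = -1

  def splitAtDelimiter(pattern: List[Int]): List[List[Int]] =
    pattern.span(_ != DELIMITER) match {
      case (x, xs) if xs.length > 1 => x :: splitAtDelimiter(xs.tail)
      case (x, _) => List(x)
    }

  // Behaviorally identical to the zip-based insertDelimiters in this commit
  def insertDelimiters(sequence: List[List[Int]]): List[Int] =
    sequence.flatMap(_ :+ DELIMITER)

  // Mirrors the new getSuffix above
  def getSuffix(item: List[Int], sequence: List[Int]): List[Int] = {
    val itemsetSeq = splitAtDelimiter(sequence)
    val index = itemsetSeq.indexOf(item)
    if (index == -1) List() else insertDelimiters(itemsetSeq.drop(index + 1))
  }

  def main(args: Array[String]): Unit = {
    // <(1 2) (3) (1 2)>, one delimiter closing each itemset
    val seq = List(1, 2, DELIMITER, 3, DELIMITER, 1, 2, DELIMITER)
    // The suffix after the first occurrence of itemset (1 2) is <(3) (1 2)>
    assert(getSuffix(List(1, 2), seq) == List(3, DELIMITER, 1, 2, DELIMITER))
    // An itemset that never occurs yields the empty suffix
    assert(getSuffix(List(9), seq) == Nil)
  }
}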
@@ -86,14 +87,16 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
    */
   private def getFreqItemAndCounts(
       minCount: Long,
-      database: Iterable[Array[Int]]): mutable.Map[Int, Long] = {
+      database: Iterable[List[Int]]): Map[List[Int], Long] = {
     // TODO: use PrimitiveKeyOpenHashMap
-    val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
+    val counts = mutable.Map[List[Int], Long]().withDefaultValue(0L)
     database.foreach { sequence =>
-      sequence.distinct.foreach { item =>
+      splitAtDelimiter(sequence).distinct.foreach { item =>
         counts(item) += 1L
       }
     }
-    counts.filter { case (item, count) => (count >= minCount) && (item != DELIMITER) }
+    counts
+      .filter { case (item, count) => (count >= minCount) && (item != DELIMITER) }
+      .toMap
   }
 }
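
Support counting now keys on whole itemsets instead of single items, and `distinct` still guarantees at most one count per sequence. A small local sketch of the same counting logic (the sample database is illustrative; splitAtDelimiter is copied from this commit):

import scala.collection.mutable

object FreqItemsetSketch {
  val DELIMITER = -1

  def splitAtDelimiter(pattern: List[Int]): List[List[Int]] =
    pattern.span(_ != DELIMITER) match {
      case (x, xs) if xs.length > 1 => x :: splitAtDelimiter(xs.tail)
      case (x, _) => List(x)
    }

  def main(args: Array[String]): Unit = {
    // <(1 2) (3)> and <(1 2) (1 2)>
    val database = Seq(
      List(1, 2, DELIMITER, 3, DELIMITER),
      List(1, 2, DELIMITER, 1, 2, DELIMITER))
    val counts = mutable.Map[List[Int], Long]().withDefaultValue(0L)
    database.foreach { sequence =>
      // distinct: an itemset is counted at most once per sequence
      splitAtDelimiter(sequence).distinct.foreach { itemset =>
        counts(itemset) += 1L
      }
    }
    assert(counts(List(1, 2)) == 2L) // in both sequences (twice in the second)
    assert(counts(List(3)) == 1L)
  }
}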

mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala

Lines changed: 48 additions & 26 deletions
@@ -89,55 +89,58 @@ class PrefixSpan private (
 
   /**
    * Find the complete set of sequential patterns in the input sequences.
-   * @param sequences ordered sequences of itemsets. Items are represented by non-negative integers.
+   * @param data ordered sequences of itemsets. Items are represented by non-negative integers.
    *                  Each itemset has one or more items and is delimited by [[DELIMITER]].
    * @return a set of sequential pattern pairs,
    *         the key of pair is pattern (a list of elements),
    *         the value of pair is the pattern's count.
    */
   // TODO: generalize to arbitrary item-types and use mapping to Ints for internal algorithm
-  def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
-    val sc = sequences.sparkContext
+  def run(data: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
+    val sc = data.sparkContext
 
-    if (sequences.getStorageLevel == StorageLevel.NONE) {
+    if (data.getStorageLevel == StorageLevel.NONE) {
       logWarning("Input data is not cached.")
     }
+    val sequences = data.map(_.toList)
 
     // Convert min support to a min number of transactions for this dataset
     val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
 
     // (Frequent items -> number of occurrences, all items here satisfy the `minSupport` threshold
     val freqItemCounts = sequences
-      .flatMap(seq => seq.distinct.map(item => (item, 1L)))
+      .flatMap(seq => splitAtDelimiter(seq).distinct.map(item => (item, 1L)))
       .reduceByKey(_ + _)
-      .filter { case (item, count) => (count >= minCount) && (item != DELIMITER) }
+      .filter { case (item, count) => (count >= minCount) }
       .collect()
+      .toList
 
     // Pairs of (length 1 prefix, suffix consisting of frequent items)
     val itemSuffixPairs = {
       val freqItems = freqItemCounts.map(_._1).toSet
       sequences.flatMap { seq =>
-        val filteredSeq = seq.filter(item => item == DELIMITER || freqItems.contains(item))
+        val filteredSeq = insertDelimiters(
+          splitAtDelimiter(seq).filter(item => freqItems.contains(item)))
         freqItems.flatMap { item =>
           val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq)
           candidateSuffix match {
-            case suffix if !suffix.isEmpty => Some((List(DELIMITER, item), suffix))
+            case suffix if !suffix.isEmpty => Some((DELIMITER :: item, suffix))
             case _ => None
           }
         }
       }
     }
     // Accumulator for the computed results to be returned, initialized to the frequent items (i.e.
     // frequent length-one prefixes)
-    var resultsAccumulator = freqItemCounts.map(x => (List(DELIMITER, x._1), x._2))
+    var resultsAccumulator = freqItemCounts.map(x => (DELIMITER :: x._1, x._2))
 
     // Remaining work to be locally and distributively processed respectfully
     var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs)
 
     // Continue processing until no pairs for distributed processing remain (i.e. all prefixes have
     // projected database sizes <= `maxLocalProjDBSize`) or `maxPatternLength` is reached
     var patternLength = 1
-    while (pairsForDistributed.count() != 0 || patternLength < maxPatternLength) {
+    while (pairsForDistributed.count() != 0 && patternLength < maxPatternLength) {
       val (nextPatternAndCounts, nextPrefixSuffixPairs) =
         extendPrefixes(minCount, pairsForDistributed)
       pairsForDistributed.unpersist()
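
From the caller's side the encoding is unchanged: each input sequence is still a flat Array[Int] in which every itemset is closed by DELIMITER (-1); run() now converts these arrays to Lists internally. A hedged usage sketch: the no-argument constructor and the setMinSupport/setMaxPatternLength setters are assumed from the rest of this file and are not part of this diff.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.fpm.PrefixSpan

object PrefixSpanRunSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("PrefixSpanRunSketch"))
    // Two sequences: <(1 2) (3)> and <(1 2) (1 2)>
    val data = sc.parallelize(Seq(
      Array(1, 2, -1, 3, -1),
      Array(1, 2, -1, 1, 2, -1))).cache() // cache to avoid the warning above
    val patterns = new PrefixSpan() // assumed public constructor and setters
      .setMinSupport(0.5)
      .setMaxPatternLength(5)
      .run(data)
    patterns.collect().foreach { case (pattern, count) =>
      println(s"[${pattern.mkString(", ")}]: $count")
    }
    sc.stop()
  }
}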
@@ -164,8 +167,8 @@
    * @return prefix-suffix pairs partitioned by whether their projected database size is <= or
    *         greater than [[maxLocalProjDBSize]]
    */
-  private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], Array[Int])])
-    : (Array[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = {
+  private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], List[Int])])
+    : (List[(List[Int], List[Int])], RDD[(List[Int], List[Int])]) = {
     val prefixToSuffixSize = prefixSuffixPairs
       .aggregateByKey(0)(
         seqOp = { case (count, suffix) => count + suffix.length },
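
partitionByProjDBSize keeps its logic: a prefix's projected database size is the total length of its suffixes, and only prefixes at or below maxLocalProjDBSize are collected for local processing. A local-collections sketch of that decision (all values illustrative):

object PartitionBySizeSketch {
  def main(args: Array[String]): Unit = {
    val maxLocalProjDBSize = 4L
    // (prefix, suffix) pairs; suffixes of total length 4 and 6 respectively
    val prefixSuffixPairs = Seq(
      (List(-1, 1), List(3, -1, 4, -1)),
      (List(-1, 2), List(3, -1, 4, -1, 5, -1)))
    // Mirrors aggregateByKey(0)(seqOp = count + suffix.length, combOp = _ + _)
    val projDBSize = prefixSuffixPairs
      .groupBy(_._1)
      .map { case (prefix, pairs) => prefix -> pairs.map(_._2.length).sum }
    val smallPrefixes = projDBSize.filter(_._2 <= maxLocalProjDBSize).keySet
    assert(smallPrefixes == Set(List(-1, 1)))
  }
}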
@@ -177,28 +180,30 @@
       .toSet
     val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) }
     val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) }
-    (small.collect(), large)
+    (small.collect().toList, large)
   }
 
   /**
-   * Extends all prefixes by one item from their suffix and computes the resulting frequent prefixes
-   * and remaining work.
+   * Extends all prefixes by one itemset from their suffix and computes the resulting frequent
+   * prefixes and remaining work.
    * @param minCount minimum count
    * @param prefixSuffixPairs prefix (length N) and suffix pairs,
    * @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1 extended
    *         prefix, corresponding suffix) pairs.
    */
   private def extendPrefixes(
       minCount: Long,
-      prefixSuffixPairs: RDD[(List[Int], Array[Int])])
-    : (RDD[(List[Int], Long)], RDD[(List[Int], Array[Int])]) = {
+      prefixSuffixPairs: RDD[(List[Int], List[Int])])
+    : (RDD[(List[Int], Long)], RDD[(List[Int], List[Int])]) = {
 
-    // (length N prefix, item from suffix) pairs and their corresponding number of occurrences
+    // (length N prefix, itemset from suffix) pairs and their corresponding number of occurrences
     // Every (prefix :+ suffix) is guaranteed to have support exceeding `minSupport`
     val prefixItemPairAndCounts = prefixSuffixPairs
-      .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) }
+      .flatMap { case (prefix, suffix) =>
+        splitAtDelimiter(suffix).distinct.map(y => ((prefix, y), 1L))
+      }
       .reduceByKey(_ + _)
-      .filter { case (item, count) => (count >= minCount) && (item != DELIMITER) }
+      .filter { case (item, count) => (count >= minCount) }
 
     // Map from prefix to set of possible next items from suffix
     val prefixToNextItems = prefixItemPairAndCounts
@@ -210,17 +215,18 @@
 
     // Frequent patterns with length N+1 and their corresponding counts
     val extendedPrefixAndCounts = prefixItemPairAndCounts
-      .map { case ((prefix, item), count) => (DELIMITER :: item :: prefix, count) }
+      .map { case ((prefix, item), count) => (DELIMITER :: item ::: prefix, count) }
 
     // Remaining work, all prefixes will have length N+1
     val extendedPrefixAndSuffix = prefixSuffixPairs
       .filter(x => prefixToNextItems.contains(x._1))
       .flatMap { case (prefix, suffix) =>
         val frequentNextItems = prefixToNextItems(prefix)
-        val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
+        val filteredSuffix = insertDelimiters(
+          splitAtDelimiter(suffix).filter(frequentNextItems.contains(_)))
         frequentNextItems.flatMap { item =>
           LocalPrefixSpan.getSuffix(item, filteredSuffix) match {
-            case suffix if !suffix.isEmpty => Some(DELIMITER :: item :: prefix, suffix)
+            case suffix if !suffix.isEmpty => Some(DELIMITER :: item ::: prefix, suffix)
             case _ => None
           }
         }
@@ -237,17 +243,33 @@
    */
   private def getPatternsInLocal(
       minCount: Long,
-      data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = {
+      data: RDD[(List[Int], Iterable[List[Int]])]): RDD[(List[Int], Long)] = {
     data.flatMap {
       case (prefix, projDB) =>
-        LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB)
+        LocalPrefixSpan.run(minCount, maxPatternLength, prefix, projDB)
           .map { case (pattern: List[Int], count: Long) =>
             (pattern, count)
           }
     }
   }
+
 }
 
 object PrefixSpan {
-  val DELIMITER = -1
+  private[fpm] val DELIMITER = -1
+
+  private[fpm] def splitAtDelimiter(pattern: List[Int]): List[List[Int]] = {
+    pattern.span(_ != DELIMITER) match {
+      case (x, xs) if xs.length > 1 => x :: splitAtDelimiter(xs.tail)
+      case (x, xs) => List(x)
+    }
+  }
+
+  private[fpm] def insertDelimiters(sequence: List[List[Int]]): List[Int] = {
+    // TODO: avoid allocating new arrays when appending
+    sequence.zip(Seq.fill(sequence.size)(PrefixSpan.DELIMITER))
+      .flatMap { case (a: List[Int], b: Int) =>
+        a :+ b
+      }
+  }
 }
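
The two new helpers are inverses: splitAtDelimiter parses a flat encoded sequence into its itemsets, and insertDelimiters re-encodes them with one DELIMITER after each itemset. A round-trip sketch (splitAtDelimiter is copied from above; insertDelimiters is written as an equivalent flatMap of the zip-based version):

object DelimiterRoundTripSketch {
  val DELIMITER = -1

  def splitAtDelimiter(pattern: List[Int]): List[List[Int]] =
    pattern.span(_ != DELIMITER) match {
      case (x, xs) if xs.length > 1 => x :: splitAtDelimiter(xs.tail)
      case (x, _) => List(x)
    }

  def insertDelimiters(sequence: List[List[Int]]): List[Int] =
    sequence.flatMap(_ :+ DELIMITER)

  def main(args: Array[String]): Unit = {
    val encoded = List(1, 2, DELIMITER, 3, DELIMITER) // <(1 2) (3)>
    val itemsets = splitAtDelimiter(encoded)
    assert(itemsets == List(List(1, 2), List(3)))
    assert(insertDelimiters(itemsets) == encoded)
  }
}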
