
Commit d2a9b66

zhangjiajin authored and mengxr committed
[SPARK-8999] [MLLIB] PrefixSpan non-temporal sequences
mengxr Extends PrefixSpan to non-temporal itemsets. Continues work by zhangjiajin.

* Internal API uses List[Set[Int]], which is likely not efficient; will need to refactor during QA

Closes apache#7646

Author: zhangjiajin <[email protected]>
Author: Feynman Liang <[email protected]>
Author: zhang jiajin <[email protected]>

Closes apache#7818 from feynmanliang/SPARK-8999-nonTemporal and squashes the following commits:

4ded81d [Feynman Liang] Replace all filters to filter nonempty
350e67e [Feynman Liang] Code review feedback
03156ca [Feynman Liang] Fix tests, drop delimiters at boundaries of sequences
d1fe0ed [Feynman Liang] Remove comments
86ca4e5 [Feynman Liang] Fix style
7c7bf39 [Feynman Liang] Fixed itemSet sequences
6073b10 [Feynman Liang] Basic itemset functionality, failing test
1a7fb48 [Feynman Liang] Add delimiter to results
5db00aa [Feynman Liang] Working for items, not itemsets
6787716 [Feynman Liang] Working on temporal sequences
f1114b9 [Feynman Liang] Add -1 delimiter
00fe756 [Feynman Liang] Reset base files for rebase
f486dcd [zhangjiajin] change maxLocalProjDBSize and fix a bug (remove -3 from frequent items).
60a0b76 [zhangjiajin] fixed a scala style error.
740c203 [zhangjiajin] fixed a scala style error.
5785cb8 [zhangjiajin] support non-temporal sequence
a5d649d [zhangjiajin] restore original version
09dc409 [zhangjiajin] Merge branch 'master' of https://github.com/apache/spark into multiItems_2
ae8c02d [zhangjiajin] Fixed some Scala style errors.
216ab0c [zhangjiajin] Support non-temporal sequence in PrefixSpan
b572f54 [zhangjiajin] initialize file before rebase.
f06772f [zhangjiajin] fix a scala style error.
a7e50d4 [zhangjiajin] Add feature: Collect enough frequent prefixes before projection in PrefixSpan.
c1d13d0 [zhang jiajin] Delete PrefixspanSuite.scala
d9d8137 [zhang jiajin] Delete Prefixspan.scala
c6ceb63 [zhangjiajin] Add new algorithm PrefixSpan and test file.
1 parent 6503897 commit d2a9b66
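
For context, a minimal usage sketch of the patched API (assuming an existing SparkContext `sc`; the data and parameter values are illustrative, not from this commit). Each Array[Int] encodes one sequence, with itemsets separated by the new -1 delimiter described in the diffs below:

    import org.apache.spark.mllib.fpm.PrefixSpan

    // Two sequences of itemsets, encoded with the -1 delimiter:
    //   <(1 2) 3 (1 2 3)>  and  <1 3 (1 2)>
    val sequences = sc.parallelize(Seq(
      Array(1, 2, -1, 3, -1, 1, 2, 3),
      Array(1, -1, 3, -1, 1, 2)
    )).cache()  // cache to avoid the "Input data is not cached" warning

    val prefixSpan = new PrefixSpan()
      .setMinSupport(0.5)      // pattern must appear in >= 50% of sequences
      .setMaxPatternLength(5)  // cap on pattern length

    // Result: (pattern encoded with the same -1 delimiter, support count) pairs
    val patterns = prefixSpan.run(sequences)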

File tree

3 files changed (+302, -92 lines changed)


mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala

Lines changed: 29 additions & 17 deletions

@@ -25,7 +25,7 @@ import org.apache.spark.Logging
  * Calculate all patterns of a projected database in local.
  */
 private[fpm] object LocalPrefixSpan extends Logging with Serializable {
-
+  import PrefixSpan._
   /**
    * Calculate all patterns of a projected database.
    * @param minCount minimum count
@@ -39,12 +39,19 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
   def run(
       minCount: Long,
       maxPatternLength: Int,
-      prefixes: List[Int],
-      database: Iterable[Array[Int]]): Iterator[(List[Int], Long)] = {
-    if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
-    val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
-    val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
-    frequentItemAndCounts.iterator.flatMap { case (item, count) =>
+      prefixes: List[Set[Int]],
+      database: Iterable[List[Set[Int]]]): Iterator[(List[Set[Int]], Long)] = {
+    if (prefixes.length == maxPatternLength || database.isEmpty) {
+      return Iterator.empty
+    }
+    val freqItemSetsAndCounts = getFreqItemAndCounts(minCount, database)
+    val freqItems = freqItemSetsAndCounts.keys.flatten.toSet
+    val filteredDatabase = database.map { suffix =>
+      suffix
+        .map(item => freqItems.intersect(item))
+        .filter(_.nonEmpty)
+    }
+    freqItemSetsAndCounts.iterator.flatMap { case (item, count) =>
       val newPrefixes = item :: prefixes
       val newProjected = project(filteredDatabase, item)
       Iterator.single((newPrefixes, count)) ++
@@ -54,20 +61,23 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
 
   /**
    * Calculate suffix sequence immediately after the first occurrence of an item.
-   * @param item item to get suffix after
+   * @param item itemset to get suffix after
    * @param sequence sequence to extract suffix from
    * @return suffix sequence
    */
-  def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = {
-    val index = sequence.indexOf(item)
+  def getSuffix(item: Set[Int], sequence: List[Set[Int]]): List[Set[Int]] = {
+    val itemsetSeq = sequence
+    val index = itemsetSeq.indexWhere(item.subsetOf(_))
     if (index == -1) {
-      Array()
+      List()
     } else {
-      sequence.drop(index + 1)
+      itemsetSeq.drop(index + 1)
     }
   }
 
-  def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = {
+  def project(
+      database: Iterable[List[Set[Int]]],
+      prefix: Set[Int]): Iterable[List[Set[Int]]] = {
     database
       .map(getSuffix(prefix, _))
       .filter(_.nonEmpty)
@@ -81,14 +91,16 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
    */
   private def getFreqItemAndCounts(
       minCount: Long,
-      database: Iterable[Array[Int]]): mutable.Map[Int, Long] = {
+      database: Iterable[List[Set[Int]]]): Map[Set[Int], Long] = {
     // TODO: use PrimitiveKeyOpenHashMap
-    val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
+    val counts = mutable.Map[Set[Int], Long]().withDefaultValue(0L)
     database.foreach { sequence =>
-      sequence.distinct.foreach { item =>
+      sequence.flatMap(nonemptySubsets(_)).distinct.foreach { item =>
        counts(item) += 1L
      }
    }
-    counts.filter(_._2 >= minCount)
+    counts
+      .filter { case (_, count) => count >= minCount }
+      .toMap
  }
}
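
As a quick illustration of the new getSuffix semantics (a sketch with made-up values; LocalPrefixSpan is private[fpm], so this is callable only from within the org.apache.spark.mllib.fpm package): a prefix itemset now matches any suffix itemset that contains it, via subsetOf.

    val seq = List(Set(1, 2, 3), Set(4), Set(1, 2))

    // Set(1, 2) is a subset of the first itemset, so the suffix starts after it
    LocalPrefixSpan.getSuffix(Set(1, 2), seq)  // List(Set(4), Set(1, 2))

    // No itemset contains 5, so the suffix is empty
    LocalPrefixSpan.getSuffix(Set(5), seq)     // List()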

mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala

Lines changed: 74 additions & 37 deletions

@@ -17,7 +17,7 @@
 
 package org.apache.spark.mllib.fpm
 
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.ArrayBuilder
 
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
@@ -44,13 +44,14 @@ import org.apache.spark.storage.StorageLevel
 class PrefixSpan private (
     private var minSupport: Double,
     private var maxPatternLength: Int) extends Logging with Serializable {
+  import PrefixSpan._
 
   /**
    * The maximum number of items allowed in a projected database before local processing. If a
    * projected database exceeds this size, another iteration of distributed PrefixSpan is run.
    */
-  // TODO: make configurable with a better default value, 10000 may be too small
-  private val maxLocalProjDBSize: Long = 10000
+  // TODO: make configurable with a better default value
+  private val maxLocalProjDBSize: Long = 32000000L
 
   /**
    * Constructs a default instance with default parameters
@@ -90,35 +91,41 @@
 
   /**
    * Find the complete set of sequential patterns in the input sequences.
-   * @param sequences input data set, contains a set of sequences,
-   *                  a sequence is an ordered list of elements.
+   * @param data ordered sequences of itemsets. Items are represented by non-negative integers.
+   *             Each itemset has one or more items and is delimited by [[DELIMITER]].
    * @return a set of sequential pattern pairs,
    *         the key of pair is pattern (a list of elements),
    *         the value of pair is the pattern's count.
    */
-  def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
-    val sc = sequences.sparkContext
+  // TODO: generalize to arbitrary item-types and use mapping to Ints for internal algorithm
+  def run(data: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
+    val sc = data.sparkContext
 
-    if (sequences.getStorageLevel == StorageLevel.NONE) {
+    if (data.getStorageLevel == StorageLevel.NONE) {
       logWarning("Input data is not cached.")
     }
 
+    // Use List[Set[Item]] for internal computation
+    val sequences = data.map { seq => splitSequence(seq.toList) }
+
     // Convert min support to a min number of transactions for this dataset
     val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
 
     // (Frequent items -> number of occurrences, all items here satisfy the `minSupport` threshold
     val freqItemCounts = sequences
-      .flatMap(seq => seq.distinct.map(item => (item, 1L)))
+      .flatMap(seq => seq.flatMap(nonemptySubsets(_)).distinct.map(item => (item, 1L)))
      .reduceByKey(_ + _)
-      .filter(_._2 >= minCount)
+      .filter { case (item, count) => (count >= minCount) }
      .collect()
+      .toMap
 
     // Pairs of (length 1 prefix, suffix consisting of frequent items)
     val itemSuffixPairs = {
-      val freqItems = freqItemCounts.map(_._1).toSet
+      val freqItemSets = freqItemCounts.keys.toSet
+      val freqItems = freqItemSets.flatten
       sequences.flatMap { seq =>
-        val filteredSeq = seq.filter(freqItems.contains(_))
-        freqItems.flatMap { item =>
+        val filteredSeq = seq.map(item => freqItems.intersect(item)).filter(_.nonEmpty)
+        freqItemSets.flatMap { item =>
          val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq)
          candidateSuffix match {
            case suffix if !suffix.isEmpty => Some((List(item), suffix))
@@ -130,14 +137,15 @@
 
     // Accumulator for the computed results to be returned, initialized to the frequent items (i.e.
     // frequent length-one prefixes)
-    var resultsAccumulator = freqItemCounts.map(x => (List(x._1), x._2))
+    var resultsAccumulator = freqItemCounts.map { case (item, count) => (List(item), count) }.toList
 
     // Remaining work to be locally and distributively processed respectfully
     var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs)
 
     // Continue processing until no pairs for distributed processing remain (i.e. all prefixes have
-    // projected database sizes <= `maxLocalProjDBSize`)
-    while (pairsForDistributed.count() != 0) {
+    // projected database sizes <= `maxLocalProjDBSize`) or `maxPatternLength` is reached
+    var patternLength = 1
+    while (pairsForDistributed.count() != 0 && patternLength < maxPatternLength) {
       val (nextPatternAndCounts, nextPrefixSuffixPairs) =
         extendPrefixes(minCount, pairsForDistributed)
       pairsForDistributed.unpersist()
@@ -146,14 +154,15 @@
       pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK)
       pairsForLocal ++= smallerPairsPart
       resultsAccumulator ++= nextPatternAndCounts.collect()
+      patternLength += 1 // pattern length grows one per iteration
     }
 
     // Process the small projected databases locally
     val remainingResults = getPatternsInLocal(
       minCount, sc.parallelize(pairsForLocal, 1).groupByKey())
 
     (sc.parallelize(resultsAccumulator, 1) ++ remainingResults)
-      .map { case (pattern, count) => (pattern.toArray, count) }
+      .map { case (pattern, count) => (flattenSequence(pattern.reverse).toArray, count) }
   }
 
 
@@ -163,8 +172,8 @@
    * @return prefix-suffix pairs partitioned by whether their projected database size is <= or
    *         greater than [[maxLocalProjDBSize]]
    */
-  private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], Array[Int])])
-    : (Array[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = {
+  private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Set[Int]], List[Set[Int]])])
+    : (List[(List[Set[Int]], List[Set[Int]])], RDD[(List[Set[Int]], List[Set[Int]])]) = {
     val prefixToSuffixSize = prefixSuffixPairs
       .aggregateByKey(0)(
         seqOp = { case (count, suffix) => count + suffix.length },
@@ -176,28 +185,29 @@
       .toSet
     val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) }
     val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) }
-    (small.collect(), large)
+    (small.collect().toList, large)
   }
 
   /**
-   * Extends all prefixes by one item from their suffix and computes the resulting frequent prefixes
-   * and remaining work.
+   * Extends all prefixes by one itemset from their suffix and computes the resulting frequent
+   * prefixes and remaining work.
    * @param minCount minimum count
    * @param prefixSuffixPairs prefix (length N) and suffix pairs,
    * @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1 extended
    *         prefix, corresponding suffix) pairs.
    */
   private def extendPrefixes(
       minCount: Long,
-      prefixSuffixPairs: RDD[(List[Int], Array[Int])])
-    : (RDD[(List[Int], Long)], RDD[(List[Int], Array[Int])]) = {
+      prefixSuffixPairs: RDD[(List[Set[Int]], List[Set[Int]])])
+    : (RDD[(List[Set[Int]], Long)], RDD[(List[Set[Int]], List[Set[Int]])]) = {
 
-    // (length N prefix, item from suffix) pairs and their corresponding number of occurrences
+    // (length N prefix, itemset from suffix) pairs and their corresponding number of occurrences
     // Every (prefix :+ suffix) is guaranteed to have support exceeding `minSupport`
     val prefixItemPairAndCounts = prefixSuffixPairs
-      .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) }
+      .flatMap { case (prefix, suffix) =>
+        suffix.flatMap(nonemptySubsets(_)).distinct.map(y => ((prefix, y), 1L)) }
      .reduceByKey(_ + _)
-      .filter(_._2 >= minCount)
+      .filter { case (item, count) => (count >= minCount) }
 
     // Map from prefix to set of possible next items from suffix
     val prefixToNextItems = prefixItemPairAndCounts
@@ -207,7 +217,6 @@
       .collect()
       .toMap
 
-
     // Frequent patterns with length N+1 and their corresponding counts
     val extendedPrefixAndCounts = prefixItemPairAndCounts
       .map { case ((prefix, item), count) => (item :: prefix, count) }
@@ -216,9 +225,12 @@
     val extendedPrefixAndSuffix = prefixSuffixPairs
       .filter(x => prefixToNextItems.contains(x._1))
       .flatMap { case (prefix, suffix) =>
-        val frequentNextItems = prefixToNextItems(prefix)
-        val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
-        frequentNextItems.flatMap { item =>
+        val frequentNextItemSets = prefixToNextItems(prefix)
+        val frequentNextItems = frequentNextItemSets.flatten
+        val filteredSuffix = suffix
+          .map(item => frequentNextItems.intersect(item))
+          .filter(_.nonEmpty)
+        frequentNextItemSets.flatMap { item =>
          LocalPrefixSpan.getSuffix(item, filteredSuffix) match {
            case suffix if !suffix.isEmpty => Some(item :: prefix, suffix)
            case _ => None
@@ -237,13 +249,38 @@
    */
   private def getPatternsInLocal(
       minCount: Long,
-      data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = {
+      data: RDD[(List[Set[Int]], Iterable[List[Set[Int]]])]): RDD[(List[Set[Int]], Long)] = {
     data.flatMap {
-      case (prefix, projDB) =>
-        LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
-          .map { case (pattern: List[Int], count: Long) =>
-            (pattern.reverse, count)
-          }
+      case (prefix, projDB) => LocalPrefixSpan.run(minCount, maxPatternLength, prefix, projDB)
+    }
+  }
+
+}
+
+private[fpm] object PrefixSpan {
+  private[fpm] val DELIMITER = -1
+
+  /** Splits a sequence of itemsets delimited by [[DELIMITER]]. */
+  private[fpm] def splitSequence(sequence: List[Int]): List[Set[Int]] = {
+    sequence.span(_ != DELIMITER) match {
+      case (x, xs) if xs.length > 1 => x.toSet :: splitSequence(xs.tail)
+      case (x, xs) => List(x.toSet)
+    }
+  }
+
+  /** Flattens a sequence of itemsets into an Array, inserting [[DELIMITER]] between itemsets. */
+  private[fpm] def flattenSequence(sequence: List[Set[Int]]): List[Int] = {
+    val builder = ArrayBuilder.make[Int]()
+    for (itemSet <- sequence) {
+      builder += DELIMITER
+      builder ++= itemSet.toSeq.sorted
     }
+    builder.result().toList.drop(1) // drop trailing delimiter
+  }
+
+  /** Returns an iterator over all non-empty subsets of `itemSet` */
+  private[fpm] def nonemptySubsets(itemSet: Set[Int]): Iterator[Set[Int]] = {
+    // TODO: improve complexity by using partial prefixes, considering one item at a time
+    itemSet.subsets.filter(_ != Set.empty[Int])
   }
 }
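
To make the encoding conventions above concrete, a sketch of the three helpers on toy input (all are private[fpm]; the values and the subset enumeration order are illustrative):

    splitSequence(List(1, 2, -1, 3))          // List(Set(1, 2), Set(3))
    flattenSequence(List(Set(1, 2), Set(3)))  // List(1, 2, -1, 3)
    nonemptySubsets(Set(1, 2)).toList         // List(Set(1), Set(2), Set(1, 2))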
